diff --git a/compiler/src/dotty/tools/dotc/Run.scala b/compiler/src/dotty/tools/dotc/Run.scala index 3e7bba86dcf4..40a343fb1267 100644 --- a/compiler/src/dotty/tools/dotc/Run.scala +++ b/compiler/src/dotty/tools/dotc/Run.scala @@ -12,7 +12,9 @@ import typer.Typer import typer.ImportInfo.withRootImports import Decorators._ import io.AbstractFile -import Phases.unfusedPhases +import Phases.{unfusedPhases, Phase} + +import sbt.interfaces.ProgressCallback import util._ import reporting.{Suppression, Action, Profile, ActiveProfile, NoProfile} @@ -32,6 +34,10 @@ import scala.collection.mutable import scala.util.control.NonFatal import scala.io.Codec +import Run.Progress +import scala.compiletime.uninitialized +import dotty.tools.dotc.transform.MegaPhase + /** A compiler run. Exports various methods to compile source files */ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with ConstraintRunInfo { @@ -155,7 +161,7 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint } /** The source files of all late entered symbols, as a set */ - private var lateFiles = mutable.Set[AbstractFile]() + private val lateFiles = mutable.Set[AbstractFile]() /** A cache for static references to packages and classes */ val staticRefs = util.EqHashMap[Name, Denotation](initialCapacity = 1024) @@ -163,6 +169,67 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint /** Actions that need to be performed at the end of the current compilation run */ private var finalizeActions = mutable.ListBuffer[() => Unit]() + private var _progress: Progress | Null = null // Set if progress reporting is enabled + + private inline def trackProgress(using Context)(inline op: Context ?=> Progress => Unit): Unit = + foldProgress(())(op) + + private inline def foldProgress[T](using Context)(inline default: T)(inline op: Context ?=> Progress => T): T = + val local = _progress + if local != null then + op(using ctx)(local) + else + default + + def didEnterUnit(unit: CompilationUnit)(using Context): Boolean = + foldProgress(true /* should progress by default */)(_.tryEnterUnit(unit)) + + def canProgress()(using Context): Boolean = + foldProgress(true /* not cancelled by default */)(p => !p.checkCancellation()) + + def doAdvanceUnit()(using Context): Unit = + trackProgress: progress => + progress.currentUnitCount += 1 // trace that we completed a unit in the current (sub)phase + progress.refreshProgress() + + def doAdvanceLate()(using Context): Unit = + trackProgress: progress => + progress.currentLateUnitCount += 1 // trace that we completed a late compilation + progress.refreshProgress() + + private def doEnterPhase(currentPhase: Phase)(using Context): Unit = + trackProgress: progress => + progress.enterPhase(currentPhase) + + /** interrupt the thread and set cancellation state */ + private def cancelInterrupted(): Unit = + try + trackProgress(_.cancel()) + finally + Thread.currentThread().nn.interrupt() + + private def doAdvancePhase(currentPhase: Phase, wasRan: Boolean)(using Context): Unit = + trackProgress: progress => + progress.currentUnitCount = 0 // reset unit count in current (sub)phase + progress.currentCompletedSubtraversalCount = 0 // reset subphase index to initial + progress.seenPhaseCount += 1 // trace that we've seen a (sub)phase + if wasRan then + // add an extra traversal now that we completed a (sub)phase + progress.completedTraversalCount += 1 + else + // no subphases were ran, remove traversals from expected total + progress.totalTraversals -= currentPhase.traversals + + private def tryAdvanceSubPhase()(using Context): Unit = + trackProgress: progress => + if progress.canAdvanceSubPhase then + progress.currentUnitCount = 0 // reset unit count in current (sub)phase + progress.seenPhaseCount += 1 // trace that we've seen a (sub)phase + progress.completedTraversalCount += 1 // add an extra traversal now that we completed a (sub)phase + progress.currentCompletedSubtraversalCount += 1 // record that we've seen a subphase + if !progress.isCancelled() then + progress.tickSubphase() + /** Will be set to true if any of the compiled compilation units contains * a pureFunctions language import. */ @@ -233,17 +300,20 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint if ctx.settings.YnoDoubleBindings.value then ctx.base.checkNoDoubleBindings = true - def runPhases(using Context) = { + def runPhases(allPhases: Array[Phase])(using Context) = { var lastPrintedTree: PrintedTree = NoPrintedTree val profiler = ctx.profiler var phasesWereAdjusted = false - for (phase <- ctx.base.allPhases) - if (phase.isRunnable) + for phase <- allPhases do + doEnterPhase(phase) + val phaseWillRun = phase.isRunnable + if phaseWillRun then Stats.trackTime(s"phase time ms/$phase") { val start = System.currentTimeMillis val profileBefore = profiler.beforePhase(phase) - units = phase.runOn(units) + try units = phase.runOn(units) + catch case _: InterruptedException => cancelInterrupted() profiler.afterPhase(phase, profileBefore) if (ctx.settings.Xprint.value.containsPhase(phase)) for (unit <- units) @@ -260,18 +330,25 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint if !Feature.ccEnabledSomewhere then ctx.base.unlinkPhaseAsDenotTransformer(Phases.checkCapturesPhase.prev) ctx.base.unlinkPhaseAsDenotTransformer(Phases.checkCapturesPhase) - + end if + end if + end if + doAdvancePhase(phase, wasRan = phaseWillRun) + end for profiler.finished() } val runCtx = ctx.fresh runCtx.setProfiler(Profiler()) unfusedPhases.foreach(_.initContext(runCtx)) - runPhases(using runCtx) + val fusedPhases = runCtx.base.allPhases + runCtx.withProgressCallback: cb => + _progress = Progress(cb, this, fusedPhases.map(_.traversals).sum) + runPhases(allPhases = fusedPhases)(using runCtx) if (!ctx.reporter.hasErrors) Rewrites.writeBack() suppressions.runFinished(hasErrors = ctx.reporter.hasErrors) - while (finalizeActions.nonEmpty) { + while (finalizeActions.nonEmpty && canProgress()) { val action = finalizeActions.remove(0) action() } @@ -293,10 +370,9 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint .withRootImports def process()(using Context) = - ctx.typer.lateEnterUnit(doTypeCheck => - if typeCheck then - if compiling then finalizeActions += doTypeCheck - else doTypeCheck() + ctx.typer.lateEnterUnit(typeCheck)(doTypeCheck => + if compiling then finalizeActions += doTypeCheck + else doTypeCheck() ) process()(using unitCtx) @@ -399,7 +475,129 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint } object Run { + + case class SubPhase(val name: String): + override def toString: String = name + + class SubPhases(val phase: Phase): + require(phase.exists) + + private def baseName: String = phase match + case phase: MegaPhase => phase.shortPhaseName + case phase => phase.phaseName + + val all = IArray.from(phase.subPhases.map(sub => s"$baseName[$sub]")) + + def next(using Context): Option[SubPhases] = + val next0 = phase.megaPhase.next.megaPhase + if next0.exists then Some(SubPhases(next0)) + else None + + def size: Int = all.size + + def subPhase(index: Int) = + if index < all.size then all(index) + else baseName + + + private class Progress(cb: ProgressCallback, private val run: Run, val initialTraversals: Int): + export cb.{cancel, isCancelled} + + var totalTraversals: Int = initialTraversals // track how many phases we expect to run + var currentUnitCount: Int = 0 // current unit count in the current (sub)phase + var currentLateUnitCount: Int = 0 // current late unit count + var completedTraversalCount: Int = 0 // completed traversals over all files + var currentCompletedSubtraversalCount: Int = 0 // completed subphases in the current phase + var seenPhaseCount: Int = 0 // how many phases we've seen so far + + private var currPhase: Phase = uninitialized // initialized by enterPhase + private var subPhases: SubPhases = uninitialized // initialized by enterPhase + private var currPhaseName: String = uninitialized // initialized by enterPhase + private var nextPhaseName: String = uninitialized // initialized by enterPhase + + /** Enter into a new real phase, setting the current and next (sub)phases */ + def enterPhase(newPhase: Phase)(using Context): Unit = + if newPhase ne currPhase then + currPhase = newPhase + subPhases = SubPhases(newPhase) + tickSubphase() + + def canAdvanceSubPhase: Boolean = + currentCompletedSubtraversalCount + 1 < subPhases.size + + /** Compute the current (sub)phase name and next (sub)phase name */ + def tickSubphase()(using Context): Unit = + val index = currentCompletedSubtraversalCount + val s = subPhases + currPhaseName = s.subPhase(index) + nextPhaseName = + if index + 1 < s.all.size then s.subPhase(index + 1) + else s.next match + case None => "" + case Some(next0) => next0.subPhase(0) + if seenPhaseCount > 0 then + refreshProgress() + + + /** Counts the number of completed full traversals over files, plus the number of units in the current phase */ + private def currentProgress(): Int = + completedTraversalCount * work() + currentUnitCount + currentLateUnitCount + + /**Total progress is computed as the sum of + * - the number of traversals we expect to make over all files + * - the number of late compilations + */ + private def totalProgress(): Int = + totalTraversals * work() + run.lateFiles.size + + private def work(): Int = run.files.size + + private def requireInitialized(): Unit = + require((currPhase: Phase | Null) != null, "enterPhase was not called") + + def checkCancellation(): Boolean = + if Thread.interrupted() then cancel() + isCancelled() + + /** trace that we are beginning a unit in the current (sub)phase, unless cancelled */ + def tryEnterUnit(unit: CompilationUnit): Boolean = + if checkCancellation() then false + else + requireInitialized() + cb.informUnitStarting(currPhaseName, unit) + true + + /** trace the current progress out of the total, in the current (sub)phase, reporting the next (sub)phase */ + def refreshProgress()(using Context): Unit = + requireInitialized() + val total = totalProgress() + if total > 0 && !cb.progress(currentProgress(), total, currPhaseName, nextPhaseName) then + cancel() + extension (run: Run | Null) + + /** record that the current phase has begun for the compilation unit of the current Context */ + def enterUnit(unit: CompilationUnit)(using Context): Boolean = + if run != null then run.didEnterUnit(unit) + else true // don't check cancellation if we're not tracking progress + + /** check progress cancellation, true if not cancelled */ + def enterRegion()(using Context): Boolean = + if run != null then run.canProgress() + else true // don't check cancellation if we're not tracking progress + + /** advance the unit count and record progress in the current phase */ + def advanceUnit()(using Context): Unit = + if run != null then run.doAdvanceUnit() + + /** if there exists another subphase, switch to it and record progress */ + def enterNextSubphase()(using Context): Unit = + if run != null then run.tryAdvanceSubPhase() + + /** advance the late count and record progress in the current phase */ + def advanceLate()(using Context): Unit = + if run != null then run.doAdvanceLate() + def enrichedErrorMessage: Boolean = if run == null then false else run.myEnrichedErrorMessage def enrichErrorMessage(errorMessage: String)(using Context): String = if run == null then diff --git a/compiler/src/dotty/tools/dotc/core/Contexts.scala b/compiler/src/dotty/tools/dotc/core/Contexts.scala index 8a7f2ff4e051..20b553149edb 100644 --- a/compiler/src/dotty/tools/dotc/core/Contexts.scala +++ b/compiler/src/dotty/tools/dotc/core/Contexts.scala @@ -34,7 +34,7 @@ import scala.annotation.internal.sharable import DenotTransformers.DenotTransformer import dotty.tools.dotc.profile.Profiler -import dotty.tools.dotc.sbt.interfaces.IncrementalCallback +import dotty.tools.dotc.sbt.interfaces.{IncrementalCallback, ProgressCallback} import util.Property.Key import util.Store import plugins._ @@ -53,8 +53,9 @@ object Contexts { private val (notNullInfosLoc, store8) = store7.newLocation[List[NotNullInfo]]() private val (importInfoLoc, store9) = store8.newLocation[ImportInfo | Null]() private val (typeAssignerLoc, store10) = store9.newLocation[TypeAssigner](TypeAssigner) + private val (progressCallbackLoc, store11) = store10.newLocation[ProgressCallback | Null]() - private val initialStore = store10 + private val initialStore = store11 /** The current context */ inline def ctx(using ctx: Context): Context = ctx @@ -177,6 +178,14 @@ object Contexts { val local = incCallback local != null && local.enabled || forceRun + /** The Zinc compile progress callback implementation if we are run from Zinc, null otherwise */ + def progressCallback: ProgressCallback | Null = store(progressCallbackLoc) + + /** Run `op` if there exists a Zinc progress callback */ + inline def withProgressCallback(inline op: ProgressCallback => Unit): Unit = + val local = progressCallback + if local != null then op(local) + /** The current plain printer */ def printerFn: Context => Printer = store(printerFnLoc) @@ -675,6 +684,7 @@ object Contexts { def setCompilerCallback(callback: CompilerCallback): this.type = updateStore(compilerCallbackLoc, callback) def setIncCallback(callback: IncrementalCallback): this.type = updateStore(incCallbackLoc, callback) + def setProgressCallback(callback: ProgressCallback): this.type = updateStore(progressCallbackLoc, callback) def setPrinterFn(printer: Context => Printer): this.type = updateStore(printerFnLoc, printer) def setSettings(settingsState: SettingsState): this.type = updateStore(settingsStateLoc, settingsState) def setRun(run: Run | Null): this.type = updateStore(runLoc, run) diff --git a/compiler/src/dotty/tools/dotc/core/Phases.scala b/compiler/src/dotty/tools/dotc/core/Phases.scala index 2a3828004525..d6a49186b539 100644 --- a/compiler/src/dotty/tools/dotc/core/Phases.scala +++ b/compiler/src/dotty/tools/dotc/core/Phases.scala @@ -317,19 +317,29 @@ object Phases { /** List of names of phases that should precede this phase */ def runsAfter: Set[String] = Set.empty + /** for purposes of progress tracking, overridden in TyperPhase */ + def subPhases: List[Run.SubPhase] = Nil + final def traversals: Int = if subPhases.isEmpty then 1 else subPhases.length + /** @pre `isRunnable` returns true */ def run(using Context): Unit /** @pre `isRunnable` returns true */ def runOn(units: List[CompilationUnit])(using runCtx: Context): List[CompilationUnit] = - units.map { unit => + val buf = List.newBuilder[CompilationUnit] + for unit <- units do given unitCtx: Context = runCtx.fresh.setPhase(this.start).setCompilationUnit(unit).withRootImports - try run - catch case ex: Throwable if !ctx.run.enrichedErrorMessage => - println(ctx.run.enrichErrorMessage(s"unhandled exception while running $phaseName on $unit")) - throw ex - unitCtx.compilationUnit - } + if ctx.run.enterUnit(unit) then + try run + catch case ex: Throwable if !ctx.run.enrichedErrorMessage => + println(ctx.run.enrichErrorMessage(s"unhandled exception while running $phaseName on $unit")) + throw ex + finally ctx.run.advanceUnit() + buf += unitCtx.compilationUnit + end if + end for + buf.result() + end runOn /** Convert a compilation unit's tree to a string; can be overridden */ def show(tree: untpd.Tree)(using Context): String = @@ -436,12 +446,33 @@ object Phases { final def iterator: Iterator[Phase] = Iterator.iterate(this)(_.next) takeWhile (_.hasNext) - final def monitor(doing: String)(body: => Unit)(using Context): Unit = - try body - catch - case NonFatal(ex) if !ctx.run.enrichedErrorMessage => - report.echo(ctx.run.enrichErrorMessage(s"exception occurred while $doing ${ctx.compilationUnit}")) + /** Cancellable region, if not cancelled, run the body in the context of the current compilation unit. + * Enrich crash messages. + */ + final def monitor(doing: String)(body: Context ?=> Unit)(using Context): Boolean = + val unit = ctx.compilationUnit + if ctx.run.enterUnit(unit) then + try {body; true} + catch case NonFatal(ex) if !ctx.run.enrichedErrorMessage => + report.echo(ctx.run.enrichErrorMessage(s"exception occurred while $doing $unit")) throw ex + finally ctx.run.advanceUnit() + else + false + + inline def runSubPhase[T](id: Run.SubPhase)(inline body: (Run.SubPhase, Context) ?=> T)(using Context): T = + given Run.SubPhase = id + try + body + finally + ctx.run.enterNextSubphase() + + /** Do not run if compile progress has been cancelled */ + final def cancellable(body: Context ?=> Unit)(using Context): Boolean = + if ctx.run.enterRegion() then + {body; true} + else + false override def toString: String = phaseName } diff --git a/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala b/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala index 86ae99b3e0f9..455b6c89a0ba 100644 --- a/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala +++ b/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala @@ -12,7 +12,6 @@ import NameOps._ import ast.Trees.Tree import Phases.Phase - /** Load trees from TASTY files */ class ReadTasty extends Phase { @@ -22,7 +21,15 @@ class ReadTasty extends Phase { ctx.settings.fromTasty.value override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] = - withMode(Mode.ReadPositions)(units.flatMap(readTASTY(_))) + withMode(Mode.ReadPositions) { + val nextUnits = collection.mutable.ListBuffer.empty[CompilationUnit] + val unitContexts = units.view.map(ctx.fresh.setCompilationUnit) + for unitContext <- unitContexts if addTasty(nextUnits += _)(using unitContext) do () + nextUnits.toList + } + + def addTasty(fn: CompilationUnit => Unit)(using Context): Boolean = monitor(phaseName): + readTASTY(ctx.compilationUnit).foreach(fn) def readTASTY(unit: CompilationUnit)(using Context): Option[CompilationUnit] = unit match { case unit: TASTYCompilationUnit => @@ -77,7 +84,7 @@ class ReadTasty extends Phase { } } case unit => - Some(unit) + Some(unit) } def run(using Context): Unit = unsupported("run") diff --git a/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala b/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala index 7caff4996b85..bcabfbd03a1d 100644 --- a/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala +++ b/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala @@ -22,7 +22,7 @@ class Parser extends Phase { */ private[dotc] var firstXmlPos: SourcePosition = NoSourcePosition - def parse(using Context) = monitor("parser") { + def parse(using Context): Boolean = monitor("parser") { val unit = ctx.compilationUnit unit.untpdTree = if (unit.isJava) new JavaParsers.JavaParser(unit.source).parse() @@ -46,10 +46,15 @@ class Parser extends Phase { report.inform(s"parsing ${unit.source}") ctx.fresh.setCompilationUnit(unit).withRootImports - unitContexts.foreach(parse(using _)) + val unitContexts0 = + for + unitContext <- unitContexts + if parse(using unitContext) + yield unitContext + record("parsedTrees", ast.Trees.ntrees) - unitContexts.map(_.compilationUnit) + unitContexts0.map(_.compilationUnit) } def run(using Context): Unit = unsupported("run") diff --git a/compiler/src/dotty/tools/dotc/sbt/interfaces/ProgressCallback.java b/compiler/src/dotty/tools/dotc/sbt/interfaces/ProgressCallback.java new file mode 100644 index 000000000000..39f5ca39962b --- /dev/null +++ b/compiler/src/dotty/tools/dotc/sbt/interfaces/ProgressCallback.java @@ -0,0 +1,21 @@ +package dotty.tools.dotc.sbt.interfaces; + +import dotty.tools.dotc.CompilationUnit; + +public interface ProgressCallback { + /** Record that the cancellation signal has been received during the Zinc run. */ + default void cancel() {} + + /** Report on if there was a cancellation signal for the current Zinc run. */ + default boolean isCancelled() { return false; } + + /** Record that a unit has started compiling in the given phase. */ + default void informUnitStarting(String phase, CompilationUnit unit) {} + + /** Record the current compilation progress. + * @param current `completedPhaseCount * totalUnits + completedUnitsInCurrPhase + completedLate` + * @param total `totalPhases * totalUnits + totalLate` + * @return true if the compilation should continue (callers are expected to cancel if this returns false) + */ + default boolean progress(int current, int total, String currPhase, String nextPhase) { return true; } +} diff --git a/compiler/src/dotty/tools/dotc/semanticdb/ExtractSemanticDB.scala b/compiler/src/dotty/tools/dotc/semanticdb/ExtractSemanticDB.scala index f1b4e4637eb8..07f3fcea2e88 100644 --- a/compiler/src/dotty/tools/dotc/semanticdb/ExtractSemanticDB.scala +++ b/compiler/src/dotty/tools/dotc/semanticdb/ExtractSemanticDB.scala @@ -25,10 +25,12 @@ import scala.jdk.CollectionConverters._ import scala.PartialFunction.condOpt import typer.ImportInfo.withRootImports +import dotty.tools.dotc.reporting.Diagnostic.Warning import dotty.tools.dotc.{semanticdb => s} import dotty.tools.io.{AbstractFile, JarArchive} import dotty.tools.dotc.semanticdb.DiagnosticOps.* import scala.util.{Using, Failure, Success} +import java.nio.file.Path /** Extract symbol references and uses to semanticdb files. @@ -60,48 +62,64 @@ class ExtractSemanticDB private (phaseMode: ExtractSemanticDB.PhaseMode) extends // Check not needed since it does not transform trees override def isCheckable: Boolean = false + private def computeDiagnostics( + sourceRoot: String, + warnings: Map[SourceFile, List[Warning]], + append: ((Path, List[Diagnostic])) => Unit)(using Context): Boolean = monitor(phaseName) { + val unit = ctx.compilationUnit + warnings.get(unit.source).foreach { ws => + val outputDir = + ExtractSemanticDB.semanticdbPath( + unit.source, + ExtractSemanticDB.semanticdbOutDir, + sourceRoot + ) + append((outputDir, ws.map(_.toSemanticDiagnostic))) + } + } + + private def extractSemanticDB(sourceRoot: String, writeSemanticdbText: Boolean)(using Context): Boolean = + monitor(phaseName) { + val unit = ctx.compilationUnit + val outputDir = + ExtractSemanticDB.semanticdbPath( + unit.source, + ExtractSemanticDB.semanticdbOutDir, + sourceRoot + ) + val extractor = ExtractSemanticDB.Extractor() + extractor.extract(unit.tpdTree) + ExtractSemanticDB.write( + unit.source, + extractor.occurrences.toList, + extractor.symbolInfos.toList, + extractor.synthetics.toList, + outputDir, + sourceRoot, + writeSemanticdbText + ) + } + override def runOn(units: List[CompilationUnit])(using ctx: Context): List[CompilationUnit] = { val sourceRoot = ctx.settings.sourceroot.value val appendDiagnostics = phaseMode == ExtractSemanticDB.PhaseMode.AppendDiagnostics + val unitContexts = units.map(ctx.fresh.setCompilationUnit(_).withRootImports) if (appendDiagnostics) val warnings = ctx.reporter.allWarnings.groupBy(w => w.pos.source) - units.flatMap { unit => - warnings.get(unit.source).map { ws => - val unitCtx = ctx.fresh.setCompilationUnit(unit).withRootImports - val outputDir = - ExtractSemanticDB.semanticdbPath( - unit.source, - ExtractSemanticDB.semanticdbOutDir(using unitCtx), - sourceRoot - ) - (outputDir, ws.map(_.toSemanticDiagnostic)) + val buf = mutable.ListBuffer.empty[(Path, Seq[Diagnostic])] + val units0 = + for unitCtx <- unitContexts if computeDiagnostics(sourceRoot, warnings, buf += _)(using unitCtx) + yield unitCtx.compilationUnit + cancellable { + buf.toList.asJava.parallelStream().forEach { case (out, warnings) => + ExtractSemanticDB.appendDiagnostics(warnings, out) } - }.asJava.parallelStream().forEach { case (out, warnings) => - ExtractSemanticDB.appendDiagnostics(warnings, out) } + units0 else val writeSemanticdbText = ctx.settings.semanticdbText.value - units.foreach { unit => - val unitCtx = ctx.fresh.setCompilationUnit(unit).withRootImports - val outputDir = - ExtractSemanticDB.semanticdbPath( - unit.source, - ExtractSemanticDB.semanticdbOutDir(using unitCtx), - sourceRoot - ) - val extractor = ExtractSemanticDB.Extractor() - extractor.extract(unit.tpdTree)(using unitCtx) - ExtractSemanticDB.write( - unit.source, - extractor.occurrences.toList, - extractor.symbolInfos.toList, - extractor.synthetics.toList, - outputDir, - sourceRoot, - writeSemanticdbText - ) - } - units + for unitCtx <- unitContexts if extractSemanticDB(sourceRoot, writeSemanticdbText)(using unitCtx) + yield unitCtx.compilationUnit } def run(using Context): Unit = unsupported("run") @@ -611,4 +629,4 @@ object ExtractSemanticDB: traverse(vparam.tpt) tparams.foreach(tp => traverse(tp.rhs)) end Extractor -end ExtractSemanticDB \ No newline at end of file +end ExtractSemanticDB diff --git a/compiler/src/dotty/tools/dotc/transform/MegaPhase.scala b/compiler/src/dotty/tools/dotc/transform/MegaPhase.scala index 58c3cd7c65ed..fe70a1659036 100644 --- a/compiler/src/dotty/tools/dotc/transform/MegaPhase.scala +++ b/compiler/src/dotty/tools/dotc/transform/MegaPhase.scala @@ -145,6 +145,12 @@ class MegaPhase(val miniPhases: Array[MiniPhase]) extends Phase { if (miniPhases.length == 1) miniPhases(0).phaseName else miniPhases.map(_.phaseName).mkString("MegaPhase{", ", ", "}") + /** Used in progress reporting to avoid super long phase names, also the precision is not so important here */ + lazy val shortPhaseName: String = + if (miniPhases.length == 1) miniPhases(0).phaseName + else + s"MegaPhase{${miniPhases.head.phaseName},...,${miniPhases.last.phaseName}}" + private var relaxedTypingCache: Boolean = _ private var relaxedTypingKnown = false diff --git a/compiler/src/dotty/tools/dotc/transform/init/Checker.scala b/compiler/src/dotty/tools/dotc/transform/init/Checker.scala index 1efb3c88149e..523a82dcd947 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Checker.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Checker.scala @@ -18,6 +18,7 @@ import Phases._ import scala.collection.mutable import Semantic._ +import dotty.tools.unsupported class Checker extends Phase: @@ -30,19 +31,28 @@ class Checker extends Phase: override def isEnabled(using Context): Boolean = super.isEnabled && ctx.settings.YcheckInit.value + def traverse(traverser: InitTreeTraverser)(using Context): Boolean = monitor(phaseName): + val unit = ctx.compilationUnit + traverser.traverse(unit.tpdTree) + override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] = val checkCtx = ctx.fresh.setPhase(this.start) val traverser = new InitTreeTraverser() - units.foreach { unit => traverser.traverse(unit.tpdTree) } - val classes = traverser.getClasses() + val unitContexts = units.map(unit => checkCtx.fresh.setCompilationUnit(unit)) + + val units0 = + for unitContext <- unitContexts if traverse(traverser)(using unitContext) yield unitContext.compilationUnit + + cancellable { + val classes = traverser.getClasses() - Semantic.checkClasses(classes)(using checkCtx) + Semantic.checkClasses(classes)(using checkCtx) + } - units + units0 + end runOn - def run(using Context): Unit = - // ignore, we already called `Semantic.check()` in `runOn` - () + def run(using Context): Unit = unsupported("run") class InitTreeTraverser extends TreeTraverser: private val classes: mutable.ArrayBuffer[ClassSymbol] = new mutable.ArrayBuffer diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index 36ffbd2e64a4..cbc2796c3895 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -722,20 +722,27 @@ class Namer { typer: Typer => * Will call the callback with an implementation of type checking * That will set the tpdTree and root tree for the compilation unit. */ - def lateEnterUnit(typeCheckCB: (() => Unit) => Unit)(using Context) = + def lateEnterUnit(typeCheck: Boolean)(typeCheckCB: (() => Unit) => Unit)(using Context) = val unit = ctx.compilationUnit /** Index symbols in unit.untpdTree with lateCompile flag = true */ def lateEnter()(using Context): Context = val saved = lateCompile lateCompile = true - try index(unit.untpdTree :: Nil) finally lateCompile = saved + try + index(unit.untpdTree :: Nil) + finally + lateCompile = saved + if !typeCheck then ctx.run.advanceLate() /** Set the tpdTree and root tree of the compilation unit */ def lateTypeCheck()(using Context) = - unit.tpdTree = typer.typedExpr(unit.untpdTree) - val phase = new transform.SetRootTree() - phase.run + try + unit.tpdTree = typer.typedExpr(unit.untpdTree) + val phase = new transform.SetRootTree() + phase.run + finally + if typeCheck then ctx.run.advanceLate() unit.untpdTree = if (unit.isJava) new JavaParser(unit.source).parse() @@ -746,9 +753,10 @@ class Namer { typer: Typer => // inline body annotations are set in namer, capturing the current context // we need to prepare the context for inlining. lateEnter() - typeCheckCB { () => - lateTypeCheck() - } + if typeCheck then + typeCheckCB { () => + lateTypeCheck() + } } } end lateEnterUnit diff --git a/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala b/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala index f0218413d6ab..857ed1bad4d9 100644 --- a/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala +++ b/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala @@ -3,6 +3,7 @@ package dotc package typer import core._ +import Run.SubPhase import Phases._ import Contexts._ import Symbols._ @@ -31,13 +32,13 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { // Run regardless of parsing errors override def isRunnable(implicit ctx: Context): Boolean = true - def enterSyms(using Context): Unit = monitor("indexing") { + def enterSyms(using Context)(using subphase: SubPhase): Boolean = monitor(subphase.name) { val unit = ctx.compilationUnit ctx.typer.index(unit.untpdTree) typr.println("entered: " + unit.source) } - def typeCheck(using Context): Unit = monitor("typechecking") { + def typeCheck(using Context)(using subphase: SubPhase): Boolean = monitor(subphase.name) { val unit = ctx.compilationUnit try if !unit.suspended then @@ -49,7 +50,7 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { catch case _: CompilationUnit.SuspendException => () } - def javaCheck(using Context): Unit = monitor("checking java") { + def javaCheck(using Context)(using subphase: SubPhase): Boolean = monitor(subphase.name) { val unit = ctx.compilationUnit if unit.isJava then JavaChecks.check(unit.tpdTree) @@ -58,7 +59,11 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { protected def discardAfterTyper(unit: CompilationUnit)(using Context): Boolean = unit.isJava || unit.suspended + override val subPhases: List[SubPhase] = List( + SubPhase("indexing"), SubPhase("typechecking"), SubPhase("checkingJava")) + override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] = + val List(Indexing @ _, Typechecking @ _, CheckingJava @ _) = subPhases: @unchecked val unitContexts = for unit <- units yield val newCtx0 = ctx.fresh.setPhase(this.start).setCompilationUnit(unit) @@ -69,7 +74,12 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { else newCtx - unitContexts.foreach(enterSyms(using _)) + val unitContexts0 = runSubPhase(Indexing) { + for + unitContext <- unitContexts + if enterSyms(using unitContext) + yield unitContext + } ctx.base.parserPhase match { case p: ParserPhase => @@ -81,11 +91,22 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { case _ => } - unitContexts.foreach(typeCheck(using _)) + val unitContexts1 = runSubPhase(Typechecking) { + for + unitContext <- unitContexts0 + if typeCheck(using unitContext) + yield unitContext + } + record("total trees after typer", ast.Trees.ntrees) - unitContexts.foreach(javaCheck(using _)) // after typechecking to avoid cycles - val newUnits = unitContexts.map(_.compilationUnit).filterNot(discardAfterTyper) + val unitContexts2 = runSubPhase(CheckingJava) { + for + unitContext <- unitContexts1 + if javaCheck(using unitContext) // after typechecking to avoid cycles + yield unitContext + } + val newUnits = unitContexts2.map(_.compilationUnit).filterNot(discardAfterTyper) ctx.run.nn.checkSuspendedUnits(newUnits) newUnits diff --git a/compiler/test/dotty/tools/DottyTest.scala b/compiler/test/dotty/tools/DottyTest.scala index 54cf0e0c177c..7ccbc09a4c92 100644 --- a/compiler/test/dotty/tools/DottyTest.scala +++ b/compiler/test/dotty/tools/DottyTest.scala @@ -44,9 +44,14 @@ trait DottyTest extends ContextEscapeDetection { fc.setProperty(ContextDoc, new ContextDocstrings) } + protected def defaultCompiler: Compiler = new Compiler() + private def compilerWithChecker(phase: String)(assertion: (tpd.Tree, Context) => Unit) = new Compiler { + + private val baseCompiler = defaultCompiler + override def phases = { - val allPhases = super.phases + val allPhases = baseCompiler.phases val targetPhase = allPhases.flatten.find(p => p.phaseName == phase).get val groupsBefore = allPhases.takeWhile(x => !x.contains(targetPhase)) val lastGroup = allPhases.find(x => x.contains(targetPhase)).get.takeWhile(x => !(x eq targetPhase)) @@ -67,6 +72,15 @@ trait DottyTest extends ContextEscapeDetection { run.runContext } + def checkAfterCompile(checkAfterPhase: String, sources: List[String])(assertion: Context => Unit): Context = { + val c = defaultCompiler + val run = c.newRun + run.compileFromStrings(sources) + val rctx = run.runContext + assertion(rctx) + rctx + } + def checkTypes(source: String, typeStrings: String*)(assertion: (List[Type], Context) => Unit): Unit = checkTypes(source, List(typeStrings.toList)) { (tpess, ctx) => (tpess: @unchecked) match { case List(tpes) => assertion(tpes, ctx) diff --git a/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala b/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala new file mode 100644 index 000000000000..489dc0f1759c --- /dev/null +++ b/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala @@ -0,0 +1,283 @@ +package dotty.tools.dotc.sbt + +import dotty.tools.DottyTest +import dotty.tools.dotc.core.Contexts.FreshContext +import dotty.tools.dotc.sbt.ProgressCallbackTest.* + +import org.junit.Assert.* +import org.junit.Test + +import dotty.tools.toOption +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Contexts.ctx +import dotty.tools.dotc.CompilationUnit +import dotty.tools.dotc.Compiler +import dotty.tools.dotc.Run +import dotty.tools.dotc.core.Phases.Phase +import dotty.tools.io.VirtualDirectory +import dotty.tools.dotc.NoCompilationUnit +import dotty.tools.dotc.interactive.Interactive.Include.all + +final class ProgressCallbackTest extends DottyTest: + + @Test + def testCallback: Unit = + val source1 = """class Foo""" + val source2 = """class Bar""" + + inspectProgress(List(source1, source2), terminalPhase = None): progressCallback => + locally: + // (1) assert that the way we compute next phase in `Run.doAdvancePhase` is correct + assertNextPhaseIsNext() + + locally: + // (1) given correct computation, check that the recorded progression of phases is monotonic + assertMonotonicProgression(progressCallback) + + locally: + // (1) given monotonic progression, check that the recorded progression of phases is complete + val expectedCurr = allSubPhases + val expectedNext = expectedCurr.tail ++ syntheticNextPhases + assertProgressPhases(progressCallback, expectedCurr, expectedNext) + + locally: + // (2) next check that for each unit, we record all the "runnable" phases that could go through + assertExpectedPhasesForUnits(progressCallback, expectedPhases = runnableSubPhases) + + locally: + // (2) therefore we can now cross-reference the recorded progression with the recorded phases per unit + assertTotalUnits(progressCallback) + + locally: + // (3) finally, check that the callback was not cancelled + assertFalse(progressCallback.isCancelled) + end testCallback + + // TODO: test cancellation + + @Test + def cancelMidTyper: Unit = + inspectCancellationAtPhase("typer[typechecking]") + + @Test + def cancelErasure: Unit = + inspectCancellationAtPhase("erasure") + + @Test + def cancelPickler: Unit = + inspectCancellationAtPhase("pickler") + + def cancelOnEnter(targetPhase: String)(testCallback: TestProgressCallback): Boolean = + testCallback.latestProgress.exists(_.currPhase == targetPhase) + + def inspectCancellationAtPhase(targetPhase: String): Unit = + val source1 = """class Foo""" + + inspectProgress(List(source1), cancellation = Some(cancelOnEnter(targetPhase))): progressCallback => + locally: + // (1) assert that the compiler was cancelled + assertTrue("should have cancelled", progressCallback.isCancelled) + + locally: + // (2) assert that compiler visited all the subphases before cancellation, + // and does not visit any after. + // (2.2) first extract the surrounding phases of the target + val (befores, target +: next +: _) = allSubPhases.span(_ != targetPhase): @unchecked + // (2.3) we expect to see the subphases before&including target reported as a "current" phase, so extract here + val expectedCurr = befores :+ target + // (2.4) we expect to see next after target reported as a "next" phase, so extract here + val expectedNext = expectedCurr.tail :+ next + assertProgressPhases(progressCallback, expectedCurr, expectedNext) + + locally: + // (3) assert that the compilation units were only entered in the phases before cancellation + val (befores, target +: next +: _) = runnableSubPhases.span(_ != targetPhase): @unchecked + assertExpectedPhasesForUnits(progressCallback, expectedPhases = befores) + + locally: + // (4) assert that the final progress recorded is at the target phase, + // and progress is equal to the number of phases before the target. + val (befores, target +: next +: _) = runnableSubPhases.span(_ != targetPhase): @unchecked + // (4.1) we expect cancellation to occur *as we enter* the target phase, + // so no units should be visited in this phase. Therefore progress + // should be equal to the number of phases before the target. (as we have 1 unit) + val expectedProgress = befores.size + progressCallback.latestProgress match + case Some(ProgressEvent(`expectedProgress`, _, `target`, `next`)) => + case other => fail(s"did not match expected progress, found $other") + end inspectCancellationAtPhase + + /** Assert that the computed `next` phase matches the real next phase */ + def assertNextPhaseIsNext()(using Context): Unit = + val allPhases = ctx.base.allPhases + for case Array(p1, p2) <- allPhases.sliding(2) do + val p1Next = Run.SubPhases(p1).next.get.phase // used to compute the next phase in `Run.doAdvancePhase` + assertEquals(p1Next, p2) + + /** Assert that the recorded progression of phases are all in the real progression, and that order is preserved */ + def assertMonotonicProgression(progressCallback: TestProgressCallback)(using Context): Unit = + val allPhasePlan = ctx.base.allPhases.flatMap(asSubphases) ++ syntheticNextPhases + for case List( + PhaseTransition(curr1, next1), + PhaseTransition(curr2, next2) + ) <- progressCallback.progressPhasesFinal.sliding(2) do + val curr1Index = indexOrFail(allPhasePlan, curr1) + val curr2Index = indexOrFail(allPhasePlan, curr2) + val next1Index = indexOrFail(allPhasePlan, next1) + val next2Index = indexOrFail(allPhasePlan, next2) + assertTrue(s"Phase `$curr1` did not come before `$curr2`", curr1Index < curr2Index) + assertTrue(s"Phase `$next1` did not come before `$next2`", next1Index < next2Index) + assertTrue(s"Phase `$curr1` did not come before `$next1`", curr1Index < next1Index) + assertTrue(s"Phase `$curr2` did not come before `$next2`", curr2Index < next2Index) + assertTrue(s"Predicted next phase `$next1` didn't match the following current `$curr2`", next1Index == curr2Index) + + /** Assert that the recorded progression of phases contains every phase in the plan */ + def assertProgressPhases(progressCallback: TestProgressCallback, + currExpected: Seq[String], nextExpected: Seq[String])(using Context): Unit = + val (allPhasePlan, expectedCurrPhases, expectedNextPhases) = + val allPhases = currExpected + val firstPhase = allPhases.head + val expectedCurrPhases = allPhases.toSet + val expectedNextPhases = nextExpected.toSet //expectedCurrPhases - firstPhase ++ syntheticNextPhases + (allPhases.toList, expectedCurrPhases, expectedNextPhases) + + for (expectedCurr, recordedCurr) <- allPhasePlan.zip(progressCallback.progressPhasesFinal.map(_.curr)) do + assertEquals(s"Phase $recordedCurr was not expected", expectedCurr, recordedCurr) + + val (seenCurrPhases, seenNextPhases) = + val (currs0, nexts0) = progressCallback.progressPhasesFinal.unzip(Tuple.fromProductTyped) + (currs0.toSet, nexts0.toSet) + + val missingCurrPhases = expectedCurrPhases.diff(seenCurrPhases) + val extraCurrPhases = seenCurrPhases.diff(expectedCurrPhases) + assertTrue(s"these phases were not visited ${missingCurrPhases}", missingCurrPhases.isEmpty) + assertTrue(s"these phases were visited, but not in the real plan ${extraCurrPhases}", extraCurrPhases.isEmpty) + + val missingNextPhases = expectedNextPhases.diff(seenNextPhases) + val extraNextPhases = seenNextPhases.diff(expectedNextPhases) + assertTrue(s"these phases were not planned to visit, but were expected ${missingNextPhases}", missingNextPhases.isEmpty) + assertTrue(s"these phases were planned to visit, but were not in the real plan ${extraNextPhases}", extraNextPhases.isEmpty) + + + /** Assert that the phases recorded per unit match the actual phases ran on them */ + def assertExpectedPhasesForUnits(progressCallback: TestProgressCallback, expectedPhases: Seq[String])(using Context): Unit = + for (unit, visitedPhases) <- progressCallback.unitPhases do + val uniquePhases = visitedPhases.toSet + assert(unit != NoCompilationUnit, s"unexpected NoCompilationUnit for phases $uniquePhases") + val duplicatePhases = visitedPhases.view.groupBy(identity).values.filter(_.size > 1).map(_.head) + assertEquals(s"some phases were visited twice for $unit! ${duplicatePhases.toList}", visitedPhases.size, uniquePhases.size) + val unvisitedPhases = expectedPhases.filterNot(visitedPhases.contains) + val extraPhases = visitedPhases.filterNot(expectedPhases.contains) + assertTrue(s"these phases were not visited for $unit ${unvisitedPhases}", unvisitedPhases.isEmpty) + assertTrue(s"these phases were visited for $unit, but not expected ${extraPhases}", extraPhases.isEmpty) + + /** Assert that the number of total units of work matches the number of files * the runnable phases */ + def assertTotalUnits(progressCallback: TestProgressCallback)(using Context): Unit = + var fileTraversals = 0 // files * phases + for (_, phases) <- progressCallback.unitPhases do + fileTraversals += phases.size + val expectedTotal = fileTraversals // assume that no late enters occur + progressCallback.totalEvents match + case Nil => fail("No total events recorded") + case TotalEvent(total, _) :: _ => + assertEquals(expectedTotal, total) + + def inspectProgress( + sources: List[String], + terminalPhase: Option[String] = Some("typer"), + cancellation: Option[TestProgressCallback => Boolean] = None)( + op: Context ?=> TestProgressCallback => Unit)(using Context) = + for cancelNow <- cancellation do + testProgressCallback.withCancelNow(cancelNow) + val sources0 = sources.map(_.linesIterator.map(_.trim.nn).filterNot(_.isEmpty).mkString("\n|").stripMargin) + val terminalPhase0 = terminalPhase.getOrElse(defaultCompiler.phases.last.last.phaseName) + checkAfterCompile(terminalPhase0, sources0) { case given Context => + op(testProgressCallback) + } + + private def testProgressCallback(using Context): TestProgressCallback = + ctx.progressCallback match + case cb: TestProgressCallback => cb + case _ => + fail(s"Expected TestProgressCallback but got ${ctx.progressCallback}") + ??? + + override protected def initializeCtx(fc: FreshContext): Unit = + super.initializeCtx( + fc.setProgressCallback(TestProgressCallback()) + .setSetting(fc.settings.outputDir, new VirtualDirectory("")) + ) + +object ProgressCallbackTest: + + case class TotalEvent(total: Int, atPhase: String) + case class ProgressEvent(curr: Int, total: Int, currPhase: String, nextPhase: String) + case class PhaseTransition(curr: String, next: String) + + def asSubphases(phase: Phase): IndexedSeq[String] = + val subPhases = Run.SubPhases(phase) + val indices = 0 until phase.traversals + indices.map(subPhases.subPhase) + + def runnableSubPhases(using Context): IndexedSeq[String] = + ctx.base.allPhases.filter(_.isRunnable).flatMap(asSubphases).toIndexedSeq + + def allSubPhases(using Context): IndexedSeq[String] = + ctx.base.allPhases.flatMap(asSubphases).toIndexedSeq + + private val syntheticNextPhases = List("") + + /** Asserts that the computed phase name exists in the real phase plan */ + def indexOrFail(allPhasePlan: Array[String], phaseName: String): Int = + val i = allPhasePlan.indexOf(phaseName) + if i < 0 then + fail(s"Phase $phaseName not found") + i + + final class TestProgressCallback extends interfaces.ProgressCallback: + import collection.immutable, immutable.SeqMap + + private var _cancelled: Boolean = false + private var _unitPhases: SeqMap[CompilationUnit, List[String]] = immutable.SeqMap.empty // preserve order + private var _totalEvents: List[TotalEvent] = List.empty + private var _latestProgress: Option[ProgressEvent] = None + private var _progressPhases: List[PhaseTransition] = List.empty + private var _shouldCancelNow: TestProgressCallback => Boolean = _ => false + + def totalEvents = _totalEvents + def latestProgress = _latestProgress + def unitPhases = _unitPhases + def progressPhasesFinal = _progressPhases.reverse + def currentPhase = _progressPhases.headOption.map(_.curr) + + def withCancelNow(f: TestProgressCallback => Boolean): this.type = + _shouldCancelNow = f + this + + override def cancel(): Unit = _cancelled = true + override def isCancelled(): Boolean = _cancelled + + override def informUnitStarting(phase: String, unit: CompilationUnit): Unit = + _unitPhases += (unit -> (unitPhases.getOrElse(unit, Nil) :+ phase)) + + override def progress(current: Int, total: Int, currPhase: String, nextPhase: String): Boolean = + // record the total and current phase whenever the total changes + _totalEvents = _totalEvents match + case Nil => TotalEvent(total, currPhase) :: Nil + case events @ (head :: _) if head.total != total => TotalEvent(total, currPhase) :: events + case events => events + + _latestProgress = Some(ProgressEvent(current, total, currPhase, nextPhase)) + + // record the current and next phase whenever the current phase changes + _progressPhases = _progressPhases match + case all @ PhaseTransition(head, _) :: rest => + if head != currPhase then + PhaseTransition(currPhase, nextPhase) :: all + else + all + case Nil => PhaseTransition(currPhase, nextPhase) :: Nil + + !_shouldCancelNow(this) + +end ProgressCallbackTest diff --git a/sbt-bridge/src/dotty/tools/xsbt/CompilerBridge.java b/sbt-bridge/src/dotty/tools/xsbt/CompilerBridge.java index 92b8062700c4..6e2095a9df1e 100644 --- a/sbt-bridge/src/dotty/tools/xsbt/CompilerBridge.java +++ b/sbt-bridge/src/dotty/tools/xsbt/CompilerBridge.java @@ -19,6 +19,6 @@ public final class CompilerBridge implements CompilerInterface2 { public void run(VirtualFile[] sources, DependencyChanges changes, String[] options, Output output, AnalysisCallback callback, Reporter delegate, CompileProgress progress, Logger log) { CompilerBridgeDriver driver = new CompilerBridgeDriver(options, output); - driver.run(sources, callback, log, delegate); + driver.run(sources, callback, log, delegate, progress); } } diff --git a/sbt-bridge/src/dotty/tools/xsbt/CompilerBridgeDriver.java b/sbt-bridge/src/dotty/tools/xsbt/CompilerBridgeDriver.java index c5c2e0adaef4..2d54d4e83404 100644 --- a/sbt-bridge/src/dotty/tools/xsbt/CompilerBridgeDriver.java +++ b/sbt-bridge/src/dotty/tools/xsbt/CompilerBridgeDriver.java @@ -21,6 +21,7 @@ import xsbti.Problem; import xsbti.*; import xsbti.compile.Output; +import xsbti.compile.CompileProgress; import java.io.IOException; import java.io.InputStream; @@ -82,7 +83,8 @@ private static void reportMissingFile(DelegatingReporter reporter, SourceFile so reporter.reportBasicWarning(message); } - synchronized public void run(VirtualFile[] sources, AnalysisCallback callback, Logger log, Reporter delegate) { + synchronized public void run( + VirtualFile[] sources, AnalysisCallback callback, Logger log, Reporter delegate, CompileProgress progress) { VirtualFile[] sortedSources = new VirtualFile[sources.length]; System.arraycopy(sources, 0, sortedSources, 0, sources.length); Arrays.sort(sortedSources, (x0, x1) -> x0.id().compareTo(x1.id())); @@ -111,6 +113,8 @@ synchronized public void run(VirtualFile[] sources, AnalysisCallback callback, L return sourceFile.path(); }); + ProgressCallbackImpl progressCallback = new ProgressCallbackImpl(progress); + IncrementalCallback incCallback = new IncrementalCallback(callback, sourceFile -> asVirtualFile(sourceFile, reporter, lookup) ); @@ -121,7 +125,8 @@ synchronized public void run(VirtualFile[] sources, AnalysisCallback callback, L Contexts.Context initialCtx = initCtx() .fresh() .setReporter(reporter) - .setIncCallback(incCallback); + .setIncCallback(incCallback) + .setProgressCallback(progressCallback); Contexts.Context context = setup(args, initialCtx).map(t -> t._2).getOrElse(() -> initialCtx); diff --git a/sbt-bridge/src/dotty/tools/xsbt/ProgressCallbackImpl.java b/sbt-bridge/src/dotty/tools/xsbt/ProgressCallbackImpl.java new file mode 100644 index 000000000000..f5fb78f12bb1 --- /dev/null +++ b/sbt-bridge/src/dotty/tools/xsbt/ProgressCallbackImpl.java @@ -0,0 +1,35 @@ +package dotty.tools.xsbt; + +import dotty.tools.dotc.sbt.interfaces.ProgressCallback; +import dotty.tools.dotc.CompilationUnit; + +import xsbti.compile.CompileProgress; + +public final class ProgressCallbackImpl implements ProgressCallback { + private boolean _cancelled = false; // TODO: atomic boolean? + private final CompileProgress _progress; + + public ProgressCallbackImpl(CompileProgress progress) { + _progress = progress; + } + + @Override + public void cancel() { + _cancelled = true; + } + + @Override + public boolean isCancelled() { + return _cancelled; + } + + @Override + public void informUnitStarting(String phase, CompilationUnit unit) { + _progress.startUnit(phase, unit.source().file().path()); + } + + @Override + public boolean progress(int current, int total, String currPhase, String nextPhase) { + return _progress.advance(current, total, currPhase, nextPhase); + } +} diff --git a/sbt-bridge/test/xsbt/CompileProgressSpecification.scala b/sbt-bridge/test/xsbt/CompileProgressSpecification.scala new file mode 100644 index 000000000000..bcdac0547e75 --- /dev/null +++ b/sbt-bridge/test/xsbt/CompileProgressSpecification.scala @@ -0,0 +1,79 @@ +package xsbt + +import org.junit.{ Test, Ignore } +import org.junit.Assert._ + +/**Only does some rudimentary checks to assert compat with sbt. + * More thorough tests are found in compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala + */ +class CompileProgressSpecification { + + @Test + def totalIsMoreWhenSourcePath = { + val srcA = """class A""" + val srcB = """class B""" + val extraC = """trait C""" // will only exist in the `-sourcepath`, causing a late compile + val extraD = """trait D""" // will only exist in the `-sourcepath`, causing a late compile + val srcE = """class E extends C""" // depends on class in the sourcepath + val srcF = """class F extends C, D""" // depends on classes in the sourcepath + + val compilerForTesting = new ScalaCompilerForUnitTesting + + val totalA = compilerForTesting.extractTotal(srcA)() + assertTrue("expected more than 1 unit of work for a single file", totalA > 1) + + val totalB = compilerForTesting.extractTotal(srcA, srcB)() + assertEquals("expected twice the work for two sources", totalA * 2, totalB) + + val totalC = compilerForTesting.extractTotal(srcA, srcE)(extraC) + assertEquals("expected 2x+1 the work for two sources, and 1 late compile", totalA * 2 + 1, totalC) + + val totalD = compilerForTesting.extractTotal(srcA, srcF)(extraC, extraD) + assertEquals("expected 2x+2 the work for two sources, and 2 late compiles", totalA * 2 + 2, totalD) + } + + @Test + def multipleFilesVisitSamePhases = { + val srcA = """class A""" + val srcB = """class B""" + val compilerForTesting = new ScalaCompilerForUnitTesting + val Seq(phasesA, phasesB) = compilerForTesting.extractEnteredPhases(srcA, srcB) + assertTrue("expected some phases, was empty", phasesA.nonEmpty) + assertEquals(phasesA, phasesB) + } + + @Test + def multipleFiles = { + val srcA = """class A""" + val srcB = """class B""" + val compilerForTesting = new ScalaCompilerForUnitTesting + val allPhases = compilerForTesting.extractProgressPhases(srcA, srcB) + assertTrue("expected some phases, was empty", allPhases.nonEmpty) + val someExpectedPhases = // just check some "fundamental" phases, don't put all phases to avoid brittleness + Set( + "parser", + "typer[indexing]", "typer[typechecking]", "typer[checkingJava]", + "sbt-deps", + "posttyper", + "sbt-api", + "SetRootTree", + "pickler", + "inlining", + "postInlining", + "staging", + "splicing", + "pickleQuotes", + "MegaPhase{pruneErasedDefs,...,arrayConstructors}", + "erasure", + "constructors", + "genSJSIR", + "genBCode" + ) + val missingExpectedPhases = someExpectedPhases -- allPhases.toSet + val msgIfMissing = + s"missing expected phases: $missingExpectedPhases. " + + s"Either the compiler phases changed, or the encoding of Run.SubPhases.subPhase" + assertTrue(msgIfMissing, missingExpectedPhases.isEmpty) + } + +} diff --git a/sbt-bridge/test/xsbt/ExtractUsedNamesSpecification.scala b/sbt-bridge/test/xsbt/ExtractUsedNamesSpecification.scala index 819bedec3cbc..2b2b7d26c716 100644 --- a/sbt-bridge/test/xsbt/ExtractUsedNamesSpecification.scala +++ b/sbt-bridge/test/xsbt/ExtractUsedNamesSpecification.scala @@ -1,6 +1,7 @@ package xsbt import xsbti.UseScope +import ScalaCompilerForUnitTesting.Callbacks import org.junit.{ Test, Ignore } import org.junit.Assert._ @@ -226,7 +227,7 @@ class ExtractUsedNamesSpecification { def findPatMatUsages(in: String): Set[String] = { val compilerForTesting = new ScalaCompilerForUnitTesting - val (_, callback) = + val (_, Callbacks(callback, _)) = compilerForTesting.compileSrcs(List(List(sealedClass, in))) val clientNames = callback.usedNamesAndScopes.view.filterKeys(!_.startsWith("base.")) diff --git a/sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala b/sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala index 51f10e90f932..87bc45744e21 100644 --- a/sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala +++ b/sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala @@ -13,6 +13,10 @@ import dotty.tools.io.PlainFile.toPlainFile import dotty.tools.xsbt.CompilerBridge import TestCallback.ExtractedClassDependencies +import ScalaCompilerForUnitTesting.Callbacks + +object ScalaCompilerForUnitTesting: + case class Callbacks(analysis: TestCallback, progress: TestCompileProgress) /** * Provides common functionality needed for unit tests that require compiling @@ -20,12 +24,29 @@ import TestCallback.ExtractedClassDependencies */ class ScalaCompilerForUnitTesting { + def extractEnteredPhases(srcs: String*): Seq[List[String]] = { + val (tempSrcFiles, Callbacks(_, testProgress)) = compileSrcs(srcs: _*) + val run = testProgress.runs.head + tempSrcFiles.map(src => run.unitPhases(src.id)) + } + + def extractTotal(srcs: String*)(extraSourcePath: String*): Int = { + val (tempSrcFiles, Callbacks(_, testProgress)) = compileSrcs(List(srcs.toList), extraSourcePath.toList) + val run = testProgress.runs.head + run.total + } + + def extractProgressPhases(srcs: String*): List[String] = { + val (_, Callbacks(_, testProgress)) = compileSrcs(srcs: _*) + testProgress.runs.head.phases + } + /** * Compiles given source code using Scala compiler and returns API representation * extracted by ExtractAPI class. */ def extractApiFromSrc(src: String): Seq[ClassLike] = { - val (Seq(tempSrcFile), analysisCallback) = compileSrcs(src) + val (Seq(tempSrcFile), Callbacks(analysisCallback, _)) = compileSrcs(src) analysisCallback.apis(tempSrcFile) } @@ -34,7 +55,7 @@ class ScalaCompilerForUnitTesting { * extracted by ExtractAPI class. */ def extractApisFromSrcs(srcs: List[String]*): Seq[Seq[ClassLike]] = { - val (tempSrcFiles, analysisCallback) = compileSrcs(srcs.toList) + val (tempSrcFiles, Callbacks(analysisCallback, _)) = compileSrcs(srcs.toList) tempSrcFiles.map(analysisCallback.apis) } @@ -52,7 +73,7 @@ class ScalaCompilerForUnitTesting { assertDefaultScope: Boolean = true ): Map[String, Set[String]] = { // we drop temp src file corresponding to the definition src file - val (Seq(_, tempSrcFile), analysisCallback) = compileSrcs(definitionSrc, actualSrc) + val (Seq(_, tempSrcFile), Callbacks(analysisCallback, _)) = compileSrcs(definitionSrc, actualSrc) if (assertDefaultScope) for { (className, used) <- analysisCallback.usedNamesAndScopes @@ -70,7 +91,7 @@ class ScalaCompilerForUnitTesting { * Only the names used in the last src file are returned. */ def extractUsedNamesFromSrc(sources: String*): Map[String, Set[String]] = { - val (srcFiles, analysisCallback) = compileSrcs(sources: _*) + val (srcFiles, Callbacks(analysisCallback, _)) = compileSrcs(sources: _*) srcFiles .map { srcFile => val classesInSrc = analysisCallback.classNames(srcFile).map(_._1) @@ -92,7 +113,7 @@ class ScalaCompilerForUnitTesting { * file system-independent way of testing dependencies between source code "files". */ def extractDependenciesFromSrcs(srcs: List[List[String]]): ExtractedClassDependencies = { - val (_, testCallback) = compileSrcs(srcs) + val (_, Callbacks(testCallback, _)) = compileSrcs(srcs) val memberRefDeps = testCallback.classDependencies collect { case (target, src, DependencyByMemberRef) => (src, target) @@ -121,16 +142,22 @@ class ScalaCompilerForUnitTesting { * The sequence of temporary files corresponding to passed snippets and analysis * callback is returned as a result. */ - def compileSrcs(groupedSrcs: List[List[String]]): (Seq[VirtualFile], TestCallback) = { + def compileSrcs(groupedSrcs: List[List[String]], sourcePath: List[String] = Nil): (Seq[VirtualFile], Callbacks) = { val temp = IO.createTemporaryDirectory val analysisCallback = new TestCallback + val testProgress = new TestCompileProgress val classesDir = new File(temp, "classes") classesDir.mkdir() val bridge = new CompilerBridge - val files = for ((compilationUnit, unitId) <- groupedSrcs.zipWithIndex) yield { - val srcFiles = compilationUnit.toSeq.zipWithIndex.map { + val files = for ((compilationUnits, unitId) <- groupedSrcs.zipWithIndex) yield { + val extraFiles = sourcePath.toSeq.zipWithIndex.map { + case (src, i) => + val fileName = s"Extra-$unitId-$i.scala" + prepareSrcFile(temp, fileName, src) + } + val srcFiles = compilationUnits.toSeq.zipWithIndex.map { (src, i) => val fileName = s"Test-$unitId-$i.scala" prepareSrcFile(temp, fileName, src) @@ -141,23 +168,27 @@ class ScalaCompilerForUnitTesting { val output = new SingleOutput: def getOutputDirectory() = classesDir + val maybeSourcePath = if extraFiles.isEmpty then Nil else List("-sourcepath", temp.getAbsolutePath.toString) + bridge.run( virtualSrcFiles, new TestDependencyChanges, - Array("-Yforce-sbt-phases", "-classpath", classesDirPath, "-usejavacp", "-d", classesDirPath), + Array("-Yforce-sbt-phases", "-classpath", classesDirPath, "-usejavacp", "-d", classesDirPath) ++ maybeSourcePath, output, analysisCallback, new TestReporter, - new CompileProgress {}, + testProgress, new TestLogger ) + testProgress.completeRun() + srcFiles } - (files.flatten.toSeq, analysisCallback) + (files.flatten.toSeq, Callbacks(analysisCallback, testProgress)) } - def compileSrcs(srcs: String*): (Seq[VirtualFile], TestCallback) = { + def compileSrcs(srcs: String*): (Seq[VirtualFile], Callbacks) = { compileSrcs(List(srcs.toList)) } diff --git a/sbt-bridge/test/xsbti/TestCompileProgress.scala b/sbt-bridge/test/xsbti/TestCompileProgress.scala new file mode 100644 index 000000000000..d5dc81dfda24 --- /dev/null +++ b/sbt-bridge/test/xsbti/TestCompileProgress.scala @@ -0,0 +1,33 @@ +package xsbti + +import xsbti.compile.CompileProgress + +import scala.collection.mutable + +class TestCompileProgress extends CompileProgress: + class Run: + private[TestCompileProgress] val _phases: mutable.Set[String] = mutable.LinkedHashSet.empty + private[TestCompileProgress] val _unitPhases: mutable.Map[String, mutable.Set[String]] = mutable.LinkedHashMap.empty + private[TestCompileProgress] var _latestTotal: Int = 0 + + def phases: List[String] = _phases.toList + def unitPhases: collection.MapView[String, List[String]] = _unitPhases.view.mapValues(_.toList) + def total: Int = _latestTotal + + private val _runs: mutable.ListBuffer[Run] = mutable.ListBuffer.empty + private var _currentRun: Run = new Run + + def runs: List[Run] = _runs.toList + + def completeRun(): Unit = + _runs += _currentRun + _currentRun = new Run + + override def startUnit(phase: String, unitPath: String): Unit = + _currentRun._unitPhases.getOrElseUpdate(unitPath, mutable.LinkedHashSet.empty) += phase + + override def advance(current: Int, total: Int, prevPhase: String, nextPhase: String): Boolean = + _currentRun._phases += prevPhase + _currentRun._phases += nextPhase + _currentRun._latestTotal = total + true diff --git a/scaladoc/src/scala/tasty/inspector/TastyInspector.scala b/scaladoc/src/scala/tasty/inspector/TastyInspector.scala index 00aa6c5e3771..14e5f019b433 100644 --- a/scaladoc/src/scala/tasty/inspector/TastyInspector.scala +++ b/scaladoc/src/scala/tasty/inspector/TastyInspector.scala @@ -69,6 +69,7 @@ object TastyInspector: override def phaseName: String = "tastyInspector" override def runOn(units: List[CompilationUnit])(using ctx0: Context): List[CompilationUnit] = + // NOTE: although this is a phase, do not expect this to be ran with an xsbti.CompileProgress val ctx = QuotesCache.init(ctx0.fresh) runOnImpl(units)(using ctx) diff --git a/staging/src/scala/quoted/staging/QuoteCompiler.scala b/staging/src/scala/quoted/staging/QuoteCompiler.scala index eee2dacdc5f5..9fee0e41efd1 100644 --- a/staging/src/scala/quoted/staging/QuoteCompiler.scala +++ b/staging/src/scala/quoted/staging/QuoteCompiler.scala @@ -62,6 +62,7 @@ private class QuoteCompiler extends Compiler: def phaseName: String = "quotedFrontend" override def runOn(units: List[CompilationUnit])(implicit ctx: Context): List[CompilationUnit] = + // NOTE: although this is a phase, there is no need to track xsbti.CompileProgress here. units.flatMap { case exprUnit: ExprCompilationUnit => val ctx1 = ctx.fresh.setPhase(this.start).setCompilationUnit(exprUnit) diff --git a/tasty-inspector/src/scala/tasty/inspector/TastyInspector.scala b/tasty-inspector/src/scala/tasty/inspector/TastyInspector.scala index 4c6440530ba2..e70d2d4f6dc5 100644 --- a/tasty-inspector/src/scala/tasty/inspector/TastyInspector.scala +++ b/tasty-inspector/src/scala/tasty/inspector/TastyInspector.scala @@ -66,6 +66,7 @@ object TastyInspector: override def phaseName: String = "tastyInspector" override def runOn(units: List[CompilationUnit])(using ctx0: Context): List[CompilationUnit] = + // NOTE: although this is a phase, do not expect this to be ran with an xsbti.CompileProgress val ctx = QuotesCache.init(ctx0.fresh) runOnImpl(units)(using ctx)