Skip to content

Commit

Permalink
simplify subphase traversal
Browse files Browse the repository at this point in the history
  • Loading branch information
bishabosha committed Oct 25, 2023
1 parent b072662 commit b510772
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 41 deletions.
40 changes: 25 additions & 15 deletions compiler/src/dotty/tools/dotc/Run.scala
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,15 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
// no subphases were ran, remove traversals from expected total
progress.totalTraversals -= currentPhase.traversals

private def doAdvanceSubPhase()(using Context): Unit =
private def tryAdvanceSubPhase()(using Context): Unit =
trackProgress: progress =>
progress.currentUnitCount = 0 // reset unit count in current (sub)phase
progress.seenPhaseCount += 1 // trace that we've seen a (sub)phase
progress.completedTraversalCount += 1 // add an extra traversal now that we completed a (sub)phase
progress.currentCompletedSubtraversalCount += 1 // record that we've seen a subphase
if !progress.isCancelled() then
progress.tickSubphase()
if progress.canAdvanceSubPhase then
progress.currentUnitCount = 0 // reset unit count in current (sub)phase
progress.seenPhaseCount += 1 // trace that we've seen a (sub)phase
progress.completedTraversalCount += 1 // add an extra traversal now that we completed a (sub)phase
progress.currentCompletedSubtraversalCount += 1 // record that we've seen a subphase
if !progress.isCancelled() then
progress.tickSubphase()

/** Will be set to true if any of the compiled compilation units contains
* a pureFunctions language import.
Expand Down Expand Up @@ -476,20 +477,25 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint

object Run {

case class SubPhase(val name: String):
override def toString: String = name

class SubPhases(val phase: Phase):
require(phase.exists)

private def baseName: String = phase match
case phase: MegaPhase => phase.shortPhaseName
case phase => phase.phaseName

val all = IArray.from(phase.subPhases.map(sub => s"$baseName ($sub)"))
val all = IArray.from(phase.subPhases.map(sub => s"$baseName[$sub]"))

def next(using Context): Option[SubPhases] =
val next0 = phase.megaPhase.next.megaPhase
if next0.exists then Some(SubPhases(next0))
else None

def size: Int = all.size

def subPhase(index: Int) =
if index < all.size then all(index)
else baseName
Expand All @@ -511,14 +517,17 @@ object Run {
private var nextPhaseName: String = uninitialized // initialized by enterPhase

/** Enter into a new real phase, setting the current and next (sub)phases */
private[Run] def enterPhase(newPhase: Phase)(using Context): Unit =
def enterPhase(newPhase: Phase)(using Context): Unit =
if newPhase ne currPhase then
currPhase = newPhase
subPhases = SubPhases(newPhase)
tickSubphase()

def canAdvanceSubPhase: Boolean =
currentCompletedSubtraversalCount + 1 < subPhases.size

/** Compute the current (sub)phase name and next (sub)phase name */
private[Run] def tickSubphase()(using Context): Unit =
def tickSubphase()(using Context): Unit =
val index = currentCompletedSubtraversalCount
val s = subPhases
currPhaseName = s.subPhase(index)
Expand Down Expand Up @@ -547,20 +556,20 @@ object Run {
private def requireInitialized(): Unit =
require((currPhase: Phase | Null) != null, "enterPhase was not called")

private[Run] def checkCancellation(): Boolean =
def checkCancellation(): Boolean =
if Thread.interrupted() then cancel()
isCancelled()

/** trace that we are beginning a unit in the current (sub)phase, unless cancelled */
private[Run] def tryEnterUnit(unit: CompilationUnit): Boolean =
def tryEnterUnit(unit: CompilationUnit): Boolean =
if checkCancellation() then false
else
requireInitialized()
cb.informUnitStarting(currPhaseName, unit)
true

/** trace the current progress out of the total, in the current (sub)phase, reporting the next (sub)phase */
private[Run] def refreshProgress()(using Context): Unit =
def refreshProgress()(using Context): Unit =
requireInitialized()
val total = totalProgress()
if total > 0 && !cb.progress(currentProgress(), total, currPhaseName, nextPhaseName) then
Expand All @@ -582,8 +591,9 @@ object Run {
def advanceUnit()(using Context): Unit =
if run != null then run.doAdvanceUnit()

def advanceSubPhase()(using Context): Unit =
if run != null then run.doAdvanceSubPhase()
/** if there exists another subphase, switch to it and record progress */
def enterNextSubphase()(using Context): Unit =
if run != null then run.tryAdvanceSubPhase()

/** advance the late count and record progress in the current phase */
def advanceLate()(using Context): Unit =
Expand Down
9 changes: 8 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Phases.scala
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ object Phases {
def runsAfter: Set[String] = Set.empty

/** for purposes of progress tracking, overridden in TyperPhase */
def subPhases: List[String] = Nil
def subPhases: List[Run.SubPhase] = Nil
final def traversals: Int = if subPhases.isEmpty then 1 else subPhases.length

/** @pre `isRunnable` returns true */
Expand Down Expand Up @@ -472,6 +472,13 @@ object Phases {
else
false

inline def runSubPhase[T](id: Run.SubPhase)(inline body: (Run.SubPhase, Context) ?=> T)(using Context): T =
given Run.SubPhase = id
try
body
finally
ctx.run.enterNextSubphase()

/** Do not run if compile progress has been cancelled */
final def cancellable(body: Context ?=> Unit)(using Context): Boolean =
if ctx.run.enterRegion() then
Expand Down
44 changes: 21 additions & 23 deletions compiler/src/dotty/tools/dotc/typer/TyperPhase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dotc
package typer

import core._
import Run.SubPhase
import Phases._
import Contexts._
import Symbols._
Expand Down Expand Up @@ -31,13 +32,13 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase {
// Run regardless of parsing errors
override def isRunnable(implicit ctx: Context): Boolean = true

def enterSyms(using Context): Boolean = monitor("indexing") {
def enterSyms(using Context)(using subphase: SubPhase): Boolean = monitor(subphase.name) {
val unit = ctx.compilationUnit
ctx.typer.index(unit.untpdTree)
typr.println("entered: " + unit.source)
}

def typeCheck(using Context): Boolean = monitor("typechecking") {
def typeCheck(using Context)(using subphase: SubPhase): Boolean = monitor(subphase.name) {
val unit = ctx.compilationUnit
try
if !unit.suspended then
Expand All @@ -49,7 +50,7 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase {
catch case _: CompilationUnit.SuspendException => ()
}

def javaCheck(using Context): Boolean = monitor("checking java") {
def javaCheck(using Context)(using subphase: SubPhase): Boolean = monitor(subphase.name) {
val unit = ctx.compilationUnit
if unit.isJava then
JavaChecks.check(unit.tpdTree)
Expand All @@ -58,10 +59,11 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase {
protected def discardAfterTyper(unit: CompilationUnit)(using Context): Boolean =
unit.isJava || unit.suspended

/** Keep synchronised with `monitor` subcalls */
override def subPhases: List[String] = List("indexing", "typechecking", "checking java")
override val subPhases: List[SubPhase] = List(
SubPhase("indexing"), SubPhase("typechecking"), SubPhase("checkingJava"))

override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] =
val List(Indexing @ _, Typechecking @ _, CheckingJava @ _) = subPhases: @unchecked
val unitContexts =
for unit <- units yield
val newCtx0 = ctx.fresh.setPhase(this.start).setCompilationUnit(unit)
Expand All @@ -72,14 +74,12 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase {
else
newCtx

val unitContexts0 =
try
for
unitContext <- unitContexts
if enterSyms(using unitContext)
yield unitContext
finally
ctx.run.advanceSubPhase() // tick from "typer (indexing)" to "typer (typechecking)"
val unitContexts0 = runSubPhase(Indexing) {
for
unitContext <- unitContexts
if enterSyms(using unitContext)
yield unitContext
}

ctx.base.parserPhase match {
case p: ParserPhase =>
Expand All @@ -91,23 +91,21 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase {
case _ =>
}

val unitContexts1 =
try
for
unitContext <- unitContexts0
if typeCheck(using unitContext)
yield unitContext
finally
ctx.run.advanceSubPhase() // tick from "typer (typechecking)" to "typer (java checking)"
val unitContexts1 = runSubPhase(Typechecking) {
for
unitContext <- unitContexts0
if typeCheck(using unitContext)
yield unitContext
}

record("total trees after typer", ast.Trees.ntrees)

val unitContexts2 =
val unitContexts2 = runSubPhase(CheckingJava) {
for
unitContext <- unitContexts1
if javaCheck(using unitContext) // after typechecking to avoid cycles
yield unitContext

}
val newUnits = unitContexts2.map(_.compilationUnit).filterNot(discardAfterTyper)
ctx.run.nn.checkSuspendedUnits(newUnits)
newUnits
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ final class ProgressCallbackTest extends DottyTest:

@Test
def cancelMidTyper: Unit =
inspectCancellationAtPhase("typer (typechecking)")
inspectCancellationAtPhase("typer[typechecking]")

@Test
def cancelErasure: Unit =
Expand Down
2 changes: 1 addition & 1 deletion sbt-bridge/test/xsbt/CompileProgressSpecification.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class CompileProgressSpecification {
val someExpectedPhases = // just check some "fundamental" phases, don't put all phases to avoid brittleness
Set(
"parser",
"typer (indexing)", "typer (typechecking)", "typer (checking java)",
"typer[indexing]", "typer[typechecking]", "typer[checkingJava]",
"sbt-deps",
"posttyper",
"sbt-api",
Expand Down

0 comments on commit b510772

Please sign in to comment.