Skip to content

Commit

Permalink
change encoding MegaPhase name in progress tracking
Browse files Browse the repository at this point in the history
also add sbt-bridge test for CompileProgress
  • Loading branch information
bishabosha committed Oct 24, 2023
1 parent ef9fabc commit c0190c2
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 18 deletions.
10 changes: 8 additions & 2 deletions compiler/src/dotty/tools/dotc/Run.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import scala.io.Codec

import Run.Progress
import scala.compiletime.uninitialized
import dotty.tools.dotc.transform.MegaPhase

/** A compiler run. Exports various methods to compile source files */
class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with ConstraintRunInfo {
Expand Down Expand Up @@ -463,7 +464,11 @@ object Run {
class SubPhases(val phase: Phase):
require(phase.exists)

val all = IArray.from(phase.subPhases.map(sub => s"${phase.phaseName} ($sub)"))
private def baseName: String = phase match
case phase: MegaPhase => phase.shortPhaseName
case phase => phase.phaseName

val all = IArray.from(phase.subPhases.map(sub => s"$baseName ($sub)"))

def next(using Context): Option[SubPhases] =
val next0 = phase.megaPhase.next.megaPhase
Expand All @@ -472,7 +477,8 @@ object Run {

def subPhase(index: Int) =
if index < all.size then all(index)
else phase.phaseName
else baseName


private class Progress(cb: ProgressCallback, private val run: Run, val initialTraversals: Int):
private[Run] var totalTraversals: Int = initialTraversals // track how many phases we expect to run
Expand Down
6 changes: 6 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/MegaPhase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,12 @@ class MegaPhase(val miniPhases: Array[MiniPhase]) extends Phase {
if (miniPhases.length == 1) miniPhases(0).phaseName
else miniPhases.map(_.phaseName).mkString("MegaPhase{", ", ", "}")

/** Used in progress reporting to avoid super long phase names, also the precision is not so important here */
lazy val shortPhaseName: String =
if (miniPhases.length == 1) miniPhases(0).phaseName
else
s"MegaPhase{${miniPhases.head.phaseName},...,${miniPhases.last.phaseName}}"

private var relaxedTypingCache: Boolean = _
private var relaxedTypingKnown = false

Expand Down
56 changes: 56 additions & 0 deletions sbt-bridge/test/xsbt/CompileProgressSpecification.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package xsbt

import org.junit.{ Test, Ignore }
import org.junit.Assert._

/**Only does some rudimentary checks to assert compat with sbt.
* More thorough tests are found in compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala
*/
class CompileProgressSpecification {

@Test
def multipleFilesVisitSamePhases = {
val srcA = """class A"""
val srcB = """class B"""
val compilerForTesting = new ScalaCompilerForUnitTesting
val Seq(phasesA, phasesB) = compilerForTesting.extractEnteredPhases(srcA, srcB)
assertTrue("expected some phases, was empty", phasesA.nonEmpty)
assertEquals(phasesA, phasesB)
}

@Test
def multipleFiles = {
val srcA = """class A"""
val srcB = """class B"""
val compilerForTesting = new ScalaCompilerForUnitTesting
val allPhases = compilerForTesting.extractProgressPhases(srcA, srcB)
assertTrue("expected some phases, was empty", allPhases.nonEmpty)
val someExpectedPhases = // just check some "fundamental" phases, don't put all phases to avoid brittleness
Set(
"parser",
"typer (indexing)", "typer (typechecking)", "typer (checking java)",
"sbt-deps",
"extractSemanticDB",
"posttyper",
"sbt-api",
"SetRootTree",
"pickler",
"inlining",
"postInlining",
"staging",
"splicing",
"pickleQuotes",
"MegaPhase{pruneErasedDefs,...,arrayConstructors}",
"erasure",
"constructors",
"genSJSIR",
"genBCode"
)
val missingExpectedPhases = someExpectedPhases -- allPhases.toSet
val msgIfMissing =
s"missing expected phases: $missingExpectedPhases. " +
s"Either the compiler phases changed, or the encoding of Run.SubPhases.subPhase"
assertTrue(msgIfMissing, missingExpectedPhases.isEmpty)
}

}
3 changes: 2 additions & 1 deletion sbt-bridge/test/xsbt/ExtractUsedNamesSpecification.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package xsbt

import xsbti.UseScope
import ScalaCompilerForUnitTesting.Callbacks

import org.junit.{ Test, Ignore }
import org.junit.Assert._
Expand Down Expand Up @@ -226,7 +227,7 @@ class ExtractUsedNamesSpecification {

def findPatMatUsages(in: String): Set[String] = {
val compilerForTesting = new ScalaCompilerForUnitTesting
val (_, callback) =
val (_, Callbacks(callback, _)) =
compilerForTesting.compileSrcs(List(List(sealedClass, in)))
val clientNames = callback.usedNamesAndScopes.view.filterKeys(!_.startsWith("base."))

Expand Down
37 changes: 27 additions & 10 deletions sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,34 @@ import dotty.tools.io.PlainFile.toPlainFile
import dotty.tools.xsbt.CompilerBridge

import TestCallback.ExtractedClassDependencies
import ScalaCompilerForUnitTesting.Callbacks

object ScalaCompilerForUnitTesting:
case class Callbacks(analysis: TestCallback, progress: TestCompileProgress)

/**
* Provides common functionality needed for unit tests that require compiling
* source code using Scala compiler.
*/
class ScalaCompilerForUnitTesting {

def extractEnteredPhases(srcs: String*): Seq[List[String]] = {
val (tempSrcFiles, Callbacks(_, testProgress)) = compileSrcs(srcs: _*)
val run = testProgress.runs.head
tempSrcFiles.map(src => run.unitPhases(src.id))
}

def extractProgressPhases(srcs: String*): List[String] = {
val (_, Callbacks(_, testProgress)) = compileSrcs(srcs: _*)
testProgress.runs.head.phases
}

/**
* Compiles given source code using Scala compiler and returns API representation
* extracted by ExtractAPI class.
*/
def extractApiFromSrc(src: String): Seq[ClassLike] = {
val (Seq(tempSrcFile), analysisCallback) = compileSrcs(src)
val (Seq(tempSrcFile), Callbacks(analysisCallback, _)) = compileSrcs(src)
analysisCallback.apis(tempSrcFile)
}

Expand All @@ -34,7 +49,7 @@ class ScalaCompilerForUnitTesting {
* extracted by ExtractAPI class.
*/
def extractApisFromSrcs(srcs: List[String]*): Seq[Seq[ClassLike]] = {
val (tempSrcFiles, analysisCallback) = compileSrcs(srcs.toList)
val (tempSrcFiles, Callbacks(analysisCallback, _)) = compileSrcs(srcs.toList)
tempSrcFiles.map(analysisCallback.apis)
}

Expand All @@ -52,7 +67,7 @@ class ScalaCompilerForUnitTesting {
assertDefaultScope: Boolean = true
): Map[String, Set[String]] = {
// we drop temp src file corresponding to the definition src file
val (Seq(_, tempSrcFile), analysisCallback) = compileSrcs(definitionSrc, actualSrc)
val (Seq(_, tempSrcFile), Callbacks(analysisCallback, _)) = compileSrcs(definitionSrc, actualSrc)

if (assertDefaultScope) for {
(className, used) <- analysisCallback.usedNamesAndScopes
Expand All @@ -70,7 +85,7 @@ class ScalaCompilerForUnitTesting {
* Only the names used in the last src file are returned.
*/
def extractUsedNamesFromSrc(sources: String*): Map[String, Set[String]] = {
val (srcFiles, analysisCallback) = compileSrcs(sources: _*)
val (srcFiles, Callbacks(analysisCallback, _)) = compileSrcs(sources: _*)
srcFiles
.map { srcFile =>
val classesInSrc = analysisCallback.classNames(srcFile).map(_._1)
Expand All @@ -92,7 +107,7 @@ class ScalaCompilerForUnitTesting {
* file system-independent way of testing dependencies between source code "files".
*/
def extractDependenciesFromSrcs(srcs: List[List[String]]): ExtractedClassDependencies = {
val (_, testCallback) = compileSrcs(srcs)
val (_, Callbacks(testCallback, _)) = compileSrcs(srcs)

val memberRefDeps = testCallback.classDependencies collect {
case (target, src, DependencyByMemberRef) => (src, target)
Expand Down Expand Up @@ -121,7 +136,7 @@ class ScalaCompilerForUnitTesting {
* The sequence of temporary files corresponding to passed snippets and analysis
* callback is returned as a result.
*/
def compileSrcs(groupedSrcs: List[List[String]]): (Seq[VirtualFile], TestCallback) = {
def compileSrcs(groupedSrcs: List[List[String]]): (Seq[VirtualFile], Callbacks) = {
val temp = IO.createTemporaryDirectory
val analysisCallback = new TestCallback
val testProgress = new TestCompileProgress
Expand All @@ -130,8 +145,8 @@ class ScalaCompilerForUnitTesting {

val bridge = new CompilerBridge

val files = for ((compilationUnit, unitId) <- groupedSrcs.zipWithIndex) yield {
val srcFiles = compilationUnit.toSeq.zipWithIndex.map {
val files = for ((compilationUnits, unitId) <- groupedSrcs.zipWithIndex) yield {
val srcFiles = compilationUnits.toSeq.zipWithIndex.map {
(src, i) =>
val fileName = s"Test-$unitId-$i.scala"
prepareSrcFile(temp, fileName, src)
Expand All @@ -153,12 +168,14 @@ class ScalaCompilerForUnitTesting {
new TestLogger
)

testProgress.completeRun()

srcFiles
}
(files.flatten.toSeq, analysisCallback)
(files.flatten.toSeq, Callbacks(analysisCallback, testProgress))
}

def compileSrcs(srcs: String*): (Seq[VirtualFile], TestCallback) = {
def compileSrcs(srcs: String*): (Seq[VirtualFile], Callbacks) = {
compileSrcs(List(srcs.toList))
}

Expand Down
5 changes: 0 additions & 5 deletions sbt-bridge/test/xsbt/TestCompileProgress.scala

This file was deleted.

30 changes: 30 additions & 0 deletions sbt-bridge/test/xsbti/TestCompileProgress.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package xsbti

import xsbti.compile.CompileProgress

import scala.collection.mutable

class TestCompileProgress extends CompileProgress:
class Run:
private[TestCompileProgress] val _phases: mutable.Set[String] = mutable.LinkedHashSet.empty
private[TestCompileProgress] val _unitPhases: mutable.Map[String, mutable.Set[String]] = mutable.LinkedHashMap.empty

def phases: List[String] = _phases.toList
def unitPhases: collection.MapView[String, List[String]] = _unitPhases.view.mapValues(_.toList)

private val _runs: mutable.ListBuffer[Run] = mutable.ListBuffer.empty
private var _currentRun: Run = new Run

def runs: List[Run] = _runs.toList

def completeRun(): Unit =
_runs += _currentRun
_currentRun = new Run

override def startUnit(phase: String, unitPath: String): Unit =
_currentRun._unitPhases.getOrElseUpdate(unitPath, mutable.LinkedHashSet.empty) += phase

override def advance(current: Int, total: Int, prevPhase: String, nextPhase: String): Boolean =
_currentRun._phases += prevPhase
_currentRun._phases += nextPhase
true

0 comments on commit c0190c2

Please sign in to comment.