diff --git a/sbt-bridge/test/xsbt/CompileProgressSpecification.scala b/sbt-bridge/test/xsbt/CompileProgressSpecification.scala index 2297f8bb441e..32b4f58effdb 100644 --- a/sbt-bridge/test/xsbt/CompileProgressSpecification.scala +++ b/sbt-bridge/test/xsbt/CompileProgressSpecification.scala @@ -8,6 +8,30 @@ import org.junit.Assert._ */ class CompileProgressSpecification { + @Test + def totalIsMoreWhenSourcePath = { + val srcA = """class A""" + val srcB = """class B""" + val extraC = """trait C""" // will only exist in the `-sourcepath`, causing a late compile + val extraD = """trait D""" // will only exist in the `-sourcepath`, causing a late compile + val srcE = """class E extends C""" // depends on class in the sourcepath + val srcF = """class F extends C, D""" // depends on classes in the sourcepath + + val compilerForTesting = new ScalaCompilerForUnitTesting + + val totalA = compilerForTesting.extractTotal(srcA)() + assertTrue("expected more than 1 unit of work for a single file", totalA > 1) + + val totalB = compilerForTesting.extractTotal(srcA, srcB)() + assertEquals("expected twice the work for two sources", totalA * 2, totalB) + + val totalC = compilerForTesting.extractTotal(srcA, srcE)(extraC) + assertEquals("expected 2x+1 the work for two sources, and 1 late compile", totalA * 2 + 1, totalC) + + val totalD = compilerForTesting.extractTotal(srcA, srcF)(extraC, extraD) + assertEquals("expected 2x+2 the work for two sources, and 2 late compiles", totalA * 2 + 2, totalD) + } + @Test def multipleFilesVisitSamePhases = { val srcA = """class A""" diff --git a/sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala b/sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala index 520b7f7053da..87bc45744e21 100644 --- a/sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala +++ b/sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala @@ -30,6 +30,12 @@ class ScalaCompilerForUnitTesting { tempSrcFiles.map(src => run.unitPhases(src.id)) } + def extractTotal(srcs: String*)(extraSourcePath: String*): Int = { + val (tempSrcFiles, Callbacks(_, testProgress)) = compileSrcs(List(srcs.toList), extraSourcePath.toList) + val run = testProgress.runs.head + run.total + } + def extractProgressPhases(srcs: String*): List[String] = { val (_, Callbacks(_, testProgress)) = compileSrcs(srcs: _*) testProgress.runs.head.phases @@ -136,7 +142,7 @@ class ScalaCompilerForUnitTesting { * The sequence of temporary files corresponding to passed snippets and analysis * callback is returned as a result. */ - def compileSrcs(groupedSrcs: List[List[String]]): (Seq[VirtualFile], Callbacks) = { + def compileSrcs(groupedSrcs: List[List[String]], sourcePath: List[String] = Nil): (Seq[VirtualFile], Callbacks) = { val temp = IO.createTemporaryDirectory val analysisCallback = new TestCallback val testProgress = new TestCompileProgress @@ -146,6 +152,11 @@ class ScalaCompilerForUnitTesting { val bridge = new CompilerBridge val files = for ((compilationUnits, unitId) <- groupedSrcs.zipWithIndex) yield { + val extraFiles = sourcePath.toSeq.zipWithIndex.map { + case (src, i) => + val fileName = s"Extra-$unitId-$i.scala" + prepareSrcFile(temp, fileName, src) + } val srcFiles = compilationUnits.toSeq.zipWithIndex.map { (src, i) => val fileName = s"Test-$unitId-$i.scala" @@ -157,10 +168,12 @@ class ScalaCompilerForUnitTesting { val output = new SingleOutput: def getOutputDirectory() = classesDir + val maybeSourcePath = if extraFiles.isEmpty then Nil else List("-sourcepath", temp.getAbsolutePath.toString) + bridge.run( virtualSrcFiles, new TestDependencyChanges, - Array("-Yforce-sbt-phases", "-classpath", classesDirPath, "-usejavacp", "-d", classesDirPath), + Array("-Yforce-sbt-phases", "-classpath", classesDirPath, "-usejavacp", "-d", classesDirPath) ++ maybeSourcePath, output, analysisCallback, new TestReporter, diff --git a/sbt-bridge/test/xsbti/TestCompileProgress.scala b/sbt-bridge/test/xsbti/TestCompileProgress.scala index 9753a6e15b4c..d5dc81dfda24 100644 --- a/sbt-bridge/test/xsbti/TestCompileProgress.scala +++ b/sbt-bridge/test/xsbti/TestCompileProgress.scala @@ -8,9 +8,11 @@ class TestCompileProgress extends CompileProgress: class Run: private[TestCompileProgress] val _phases: mutable.Set[String] = mutable.LinkedHashSet.empty private[TestCompileProgress] val _unitPhases: mutable.Map[String, mutable.Set[String]] = mutable.LinkedHashMap.empty + private[TestCompileProgress] var _latestTotal: Int = 0 def phases: List[String] = _phases.toList def unitPhases: collection.MapView[String, List[String]] = _unitPhases.view.mapValues(_.toList) + def total: Int = _latestTotal private val _runs: mutable.ListBuffer[Run] = mutable.ListBuffer.empty private var _currentRun: Run = new Run @@ -27,4 +29,5 @@ class TestCompileProgress extends CompileProgress: override def advance(current: Int, total: Int, prevPhase: String, nextPhase: String): Boolean = _currentRun._phases += prevPhase _currentRun._phases += nextPhase + _currentRun._latestTotal = total true