Skip to content

Commit

Permalink
Make incremental compilation aware of synthesized mirrors
Browse files Browse the repository at this point in the history
A product mirror needs to be resynthesized if any class parameter changes, and
a sum mirror needs to be resynthesized if any child of the sealed type changes,
but previously this did not reliably work because the dependency recording in
ExtractDependencies was unaware of mirrors.

Instead of making ExtractDependencies aware of mirrors, we solve this by
directly recording the dependencies when the mirror is synthesized, this way we
can be sure to always correctly invalidate users of mirrors, even if the
synthesized mirror type is not present in the AST at phase ExtractDependencies.

This is the first time that we record dependencies outside of the
ExtractDependencies phase, in the future we should see if we can extend this
mechanism to record more dependencies during typechecking to make incremental
compilation more robust (e.g. by keeping track of symbols looked up by macros).

Eventually, we might even want to completely get rid of the ExtractDependencies
phase and record all dependencies on the fly if it turns out to be faster.
  • Loading branch information
smarter committed Jul 29, 2023
1 parent 54e2f59 commit 3a2141b
Show file tree
Hide file tree
Showing 13 changed files with 91 additions and 16 deletions.
9 changes: 9 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ object Contexts {
val local = incCallback
local != null && local.enabled || forceRun

/** Used to record dependencies to invalidate during incremental compilation.
* This is only used if `runZincPhases` is true.
*/
def depRecorder: sbt.DependencyRecorder = base.depRecorder

/** The current plain printer */
def printerFn: Context => Printer = store(printerFnLoc)

Expand Down Expand Up @@ -1042,6 +1047,9 @@ object Contexts {
charArray = new Array[Char](charArray.length * 2)
charArray

// Incremental compilation state
private[Contexts] val depRecorder: sbt.DependencyRecorder = sbt.DependencyRecorder()

def reset(): Unit =
uniques.clear()
uniqueAppliedTypes.clear()
Expand All @@ -1053,6 +1061,7 @@ object Contexts {
sources.clear()
files.clear()
comparers.clear() // forces re-evaluation of top and bottom classes in TypeComparer
depRecorder.clear()

// Test that access is single threaded

Expand Down
33 changes: 18 additions & 15 deletions compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,12 @@ class ExtractDependencies extends Phase {

override def run(using Context): Unit = {
val unit = ctx.compilationUnit
val rec = DependencyRecorder()
val collector = ExtractDependenciesCollector(rec)
val collector = ExtractDependenciesCollector()
collector.traverse(unit.tpdTree)

if (ctx.settings.YdumpSbtInc.value) {
val deps = rec.classDependencies.map(_.toString).toArray[Object]
val names = rec.usedNames.map { case (clazz, names) => s"$clazz: $names" }.toArray[Object]
val deps = ctx.depRecorder.classDependencies.map(_.toString).toArray[Object]
val names = ctx.depRecorder.usedNames.map { case (clazz, names) => s"$clazz: $names" }.toArray[Object]
Arrays.sort(deps)
Arrays.sort(names)

Expand All @@ -92,7 +91,7 @@ class ExtractDependencies extends Phase {
} finally pw.close()
}

rec.sendToZinc()
ctx.depRecorder.sendToZinc()
}
}

Expand All @@ -116,32 +115,32 @@ object ExtractDependencies {
* specially, see the subsection "Dependencies introduced by member reference and
* inheritance" in the "Name hashing algorithm" section.
*/
private class ExtractDependenciesCollector(rec: DependencyRecorder) extends tpd.TreeTraverser { thisTreeTraverser =>
private class ExtractDependenciesCollector() extends tpd.TreeTraverser { thisTreeTraverser =>
import tpd._

private def addMemberRefDependency(sym: Symbol)(using Context): Unit =
if (!ignoreDependency(sym)) {
rec.addUsedName(sym)
ctx.depRecorder.addUsedName(sym)
// packages have class symbol. Only record them as used names but not dependency
if (!sym.is(Package)) {
val enclOrModuleClass = if (sym.is(ModuleVal)) sym.moduleClass else sym.enclosingClass
assert(enclOrModuleClass.isClass, s"$enclOrModuleClass, $sym")

rec.addClassDependency(enclOrModuleClass, DependencyByMemberRef)
ctx.depRecorder.addClassDependency(enclOrModuleClass, DependencyByMemberRef)
}
}

private def addInheritanceDependencies(tree: Closure)(using Context): Unit =
// If the tpt is empty, this is a non-SAM lambda, so no need to register
// an inheritance relationship.
if !tree.tpt.isEmpty then
rec.addClassDependency(tree.tpt.tpe.classSymbol, LocalDependencyByInheritance)
ctx.depRecorder.addClassDependency(tree.tpt.tpe.classSymbol, LocalDependencyByInheritance)

private def addInheritanceDependencies(tree: Template)(using Context): Unit =
if (tree.parents.nonEmpty) {
val depContext = depContextOf(tree.symbol.owner)
for parent <- tree.parents do
rec.addClassDependency(parent.tpe.classSymbol, depContext)
ctx.depRecorder.addClassDependency(parent.tpe.classSymbol, depContext)
}

private def depContextOf(cls: Symbol)(using Context): DependencyContext =
Expand Down Expand Up @@ -179,7 +178,7 @@ private class ExtractDependenciesCollector(rec: DependencyRecorder) extends tpd.
for sel <- selectors if !sel.isWildcard do
addImported(sel.name)
if sel.rename != sel.name then
rec.addUsedRawName(sel.rename)
ctx.depRecorder.addUsedRawName(sel.rename)
case exp @ Export(expr, selectors) =>
val dep = expr.tpe.classSymbol
if dep.exists && selectors.exists(_.isWildcard) then
Expand All @@ -192,7 +191,7 @@ private class ExtractDependenciesCollector(rec: DependencyRecorder) extends tpd.
// inheritance dependency in the presence of wildcard exports
// to ensure all new members of `dep` are forwarded to.
val depContext = depContextOf(ctx.owner.lexicallyEnclosingClass)
rec.addClassDependency(dep, depContext)
ctx.depRecorder.addClassDependency(dep, depContext)
case t: TypeTree =>
addTypeDependency(t.tpe)
case ref: RefTree =>
Expand Down Expand Up @@ -299,7 +298,7 @@ private class ExtractDependenciesCollector(rec: DependencyRecorder) extends tpd.
val traverser = new TypeDependencyTraverser {
def addDependency(symbol: Symbol) =
if (!ignoreDependency(symbol) && symbol.is(Sealed)) {
rec.addUsedName(symbol, includeSealedChildren = true)
ctx.depRecorder.addUsedName(symbol, includeSealedChildren = true)
}
}
traverser.traverse(tpe)
Expand Down Expand Up @@ -422,8 +421,12 @@ class DependencyRecorder {
case (usedName, scopes) =>
cb.usedName(className, usedName.toString, scopes)
classDependencies.foreach(recordClassDependency(cb, _))
_usedNames.clear()
_classDependencies.clear()
clear()

/** Clear all state. */
def clear(): Unit =
_usedNames.clear()
_classDependencies.clear()

/** Handles dependency on given symbol by trying to figure out if represents a term
* that is coming from either source code (not necessarily compiled in this compilation
Expand Down
17 changes: 16 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ import ast.Trees.genericEmptyTree
import annotation.{tailrec, constructorOnly}
import ast.tpd._
import Synthesizer._
import sbt.ExtractDependencies.*
import sbt.ClassDependency
import xsbti.api.DependencyContext._

/** Synthesize terms for special classes */
class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
Expand Down Expand Up @@ -458,7 +461,13 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
val reason = s"it reduces to a tuple with arity $arity, expected arity <= $maxArity"
withErrors(i"${defn.PairClass} is not a generic product because $reason")
case MirrorSource.ClassSymbol(pre, cls) =>
if cls.isGenericProduct then makeProductMirror(pre, cls, None)
if cls.isGenericProduct then
if ctx.runZincPhases then
// The mirror should be resynthesized if the constructor of the
// case class `cls` changes. See `sbt-test/source-dependencies/mirror-product`.
ctx.depRecorder.addClassDependency(cls, DependencyByMemberRef)
ctx.depRecorder.addUsedName(cls.primaryConstructor)
makeProductMirror(pre, cls, None)
else withErrors(i"$cls is not a generic product because ${cls.whyNotGenericProduct}")
case Left(msg) =>
withErrors(i"type `$mirroredType` is not a generic product because $msg")
Expand All @@ -478,6 +487,12 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
val clsIsGenericSum = cls.isGenericSum(pre)

if acceptableMsg.isEmpty && clsIsGenericSum then
if ctx.runZincPhases then
// The mirror should be resynthesized if any child of the sealed class
// `cls` changes. See `sbt-test/source-dependencies/mirror-sum`.
ctx.depRecorder.addClassDependency(cls, DependencyByMemberRef)
ctx.depRecorder.addUsedName(cls, includeSealedChildren = true)

val elemLabels = cls.children.map(c => ConstantType(Constant(c.name.toString)))

def internalError(msg: => String)(using Context): Unit =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
case class MyProduct(x: Int)
10 changes: 10 additions & 0 deletions sbt-test/source-dependencies/mirror-product/Test.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import scala.deriving.Mirror
import scala.compiletime.erasedValue

transparent inline def foo[T](using m: Mirror.Of[T]): Int =
inline erasedValue[m.MirroredElemTypes] match
case _: (Int *: EmptyTuple) => 1
case _: (Int *: String *: EmptyTuple) => 2

@main def Test =
assert(foo[MyProduct] == 2)
1 change: 1 addition & 0 deletions sbt-test/source-dependencies/mirror-product/build.sbt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
scalaVersion := sys.props("plugin.scalaVersion")
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
case class MyProduct(x: Int, y: String)
7 changes: 7 additions & 0 deletions sbt-test/source-dependencies/mirror-product/test
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
> compile

# change the case class constructor
$ copy-file changes/MyProduct.scala MyProduct.scala

# Both MyProduct.scala and Test.scala should be recompiled, otherwise the assertion will fail
> run
2 changes: 2 additions & 0 deletions sbt-test/source-dependencies/mirror-sum/Sum.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
sealed trait Sum
case class Child1() extends Sum
12 changes: 12 additions & 0 deletions sbt-test/source-dependencies/mirror-sum/Test.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import scala.deriving.Mirror
import scala.compiletime.erasedValue

object Test:
transparent inline def foo[T](using m: Mirror.Of[T]): Int =
inline erasedValue[m.MirroredElemLabels] match
case _: ("Child1" *: EmptyTuple) => 1
case _: ("Child1" *: "Child2" *: EmptyTuple) => 2

def main(args: Array[String]): Unit =
assert(foo[Sum] == 2)

4 changes: 4 additions & 0 deletions sbt-test/source-dependencies/mirror-sum/build.sbt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
scalaVersion := sys.props("plugin.scalaVersion")
// Use more precise invalidation, otherwise the reference to `Sum` in
// Test.scala is enough to invalidate it when a child is added.
ThisBuild / incOptions ~= { _.withUseOptimizedSealed(true) }
3 changes: 3 additions & 0 deletions sbt-test/source-dependencies/mirror-sum/changes/Sum.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
sealed trait Sum
case class Child1() extends Sum
case class Child2() extends Sum
7 changes: 7 additions & 0 deletions sbt-test/source-dependencies/mirror-sum/test
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
> compile

# Add a child
$ copy-file changes/Sum.scala Sum.scala

# Both Sum.scala and Test.scala should be recompiled, otherwise the assertion will fail
> run

0 comments on commit 3a2141b

Please sign in to comment.