Skip to content

Commit

Permalink
Merge pull request #10389 from lrytz/changeOwner-invalidate-caches
Browse files Browse the repository at this point in the history
Invalidate type caches in ChangeOwnerTraverser
  • Loading branch information
lrytz committed Jul 5, 2023
2 parents 9e7c2d5 + 46af8c7 commit e4fca6e
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/compiler/scala/tools/nsc/typechecker/Infer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1305,7 +1305,7 @@ trait Infer extends Checkable {
}
}
tvars foreach instantiateTypeVar
invalidateTreeTpeCaches(tree0, tvars.map(_.origin.typeSymbol))
invalidateTreeTpeCaches(tree0, tvars.map(_.origin.typeSymbol).toSet)
}
/* If the scrutinee has free type parameters but the pattern does not,
* we have to flip the arguments so the expected type is treated as more
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/scala/tools/nsc/typechecker/Namers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ trait Namers extends MethodSynthesis {
val newFlags = (sym.flags & LOCKED) | flags
// !!! needed for: pos/t5954d; the uniques type cache will happily serve up the same TypeRef
// over this mutated symbol, and we witness a stale cache for `parents`.
invalidateCaches(sym.rawInfo, sym :: sym.moduleClass :: Nil)
invalidateCaches(sym.rawInfo, Set(sym, sym.moduleClass))
sym reset NoType setFlag newFlags setPos pos
sym.moduleClass andAlso (updatePosFlags(_, pos, moduleClassFlags(flags)))

Expand Down
2 changes: 1 addition & 1 deletion src/reflect/scala/reflect/internal/Symbols.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1873,7 +1873,7 @@ trait Symbols extends api.Symbols { self: SymbolTable =>
info match {
case ci @ ClassInfoType(_, _, _) =>
setInfo(ci.copy(parents = ci.parents :+ SerializableTpe))
invalidateCaches(ci.typeSymbol.typeOfThis, ci.typeSymbol :: Nil)
invalidateCaches(ci.typeSymbol.typeOfThis, Set(ci.typeSymbol))
case i =>
abort("Only ClassInfoTypes can be made serializable: "+ i)
}
Expand Down
31 changes: 26 additions & 5 deletions src/reflect/scala/reflect/internal/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1621,13 +1621,29 @@ trait Trees extends api.Trees {
}

class ChangeOwnerTraverser(val oldowner: Symbol, val newowner: Symbol) extends InternalTraverser {
final def change(sym: Symbol) = {
protected val changedSymbols = mutable.Set.empty[Symbol]
protected val treeTypes = mutable.Set.empty[Type]

def change(sym: Symbol) = {
if (sym != NoSymbol && sym.owner == oldowner) {
sym.owner = newowner
if (sym.isModule) sym.moduleClass.owner = newowner
changedSymbols += sym
if (sym.isModule) {
sym.moduleClass.owner = newowner
changedSymbols += sym.moduleClass
}
}
}

override def apply[T <: Tree](tree: T): T = {
traverse(tree)
if (changedSymbols.nonEmpty)
new InvalidateTypeCaches(changedSymbols).invalidate(treeTypes)
tree
}

override def traverse(tree: Tree): Unit = {
if (tree.tpe != null) treeTypes += tree.tpe
tree match {
case _: Return =>
if (tree.symbol == oldowner) {
Expand Down Expand Up @@ -1759,7 +1775,10 @@ trait Trees extends api.Trees {
*/
class TreeSymSubstituter(from: List[Symbol], to: List[Symbol]) extends InternalTransformer {
val symSubst = SubstSymMap(from, to)
private[this] var mutatedSymbols: List[Symbol] = Nil

protected val changedSymbols = mutable.Set.empty[Symbol]
protected val treeTypes = mutable.Set.empty[Type]

override def transform(tree: Tree): Tree = {
@tailrec
def subst(from: List[Symbol], to: List[Symbol]): Unit = {
Expand All @@ -1768,6 +1787,7 @@ trait Trees extends api.Trees {
else subst(from.tail, to.tail)
}
tree modifyType symSubst
if (tree.tpe != null) treeTypes += tree.tpe

if (tree.hasSymbolField) {
subst(from, to)
Expand All @@ -1780,7 +1800,7 @@ trait Trees extends api.Trees {
|TreeSymSubstituter: updated info of symbol ${sym}
| Old: ${showRaw(sym.info, printTypes = true, printIds = true)}
| New: ${showRaw(newInfo, printTypes = true, printIds = true)}""")
mutatedSymbols ::= sym
changedSymbols += sym
sym updateInfo newInfo
}
}
Expand All @@ -1805,7 +1825,8 @@ trait Trees extends api.Trees {
}
def apply[T <: Tree](tree: T): T = {
val tree1 = transform(tree)
invalidateTreeTpeCaches(tree1, mutatedSymbols)
if (changedSymbols.nonEmpty)
new InvalidateTypeCaches(changedSymbols).invalidate(treeTypes)
tree1.asInstanceOf[T]
}
override def toString() = "TreeSymSubstituter/" + substituterString("Symbol", "Symbol", from, to)
Expand Down
71 changes: 58 additions & 13 deletions src/reflect/scala/reflect/internal/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5265,24 +5265,69 @@ trait Types
*/
def importableMembers(pre: Type): Scope = pre.members filter isImportable

def invalidateTreeTpeCaches(tree: Tree, updatedSyms: List[Symbol]) = if (!updatedSyms.isEmpty)
def invalidateTreeTpeCaches(tree: Tree, updatedSyms: collection.Set[Symbol]) = if (!updatedSyms.isEmpty) {
val invldtr = new InvalidateTypeCaches(updatedSyms)
for (t <- tree if t.tpe != null)
for (tp <- t.tpe) {
invalidateCaches(tp, updatedSyms)
}
invldtr.invalidate(t.tpe)
}

def invalidateCaches(t: Type, updatedSyms: collection.Set[Symbol]): Unit =
new InvalidateTypeCaches(updatedSyms).invalidate(t)

class InvalidateTypeCaches(changedSymbols: collection.Set[Symbol]) extends TypeFolder {
private var res = false
private val seen = new java.util.IdentityHashMap[Type, Boolean]

def invalidate(tps: Iterable[Type]): Unit = {
res = false
seen.clear()
try tps.foreach(invalidateImpl)
finally seen.clear()
}

def invalidate(tp: Type): Unit = invalidate(List(tp))

protected def invalidateImpl(tp: Type): Boolean = Option(seen.get(tp)).getOrElse {
val saved = res
try {
apply(tp)
res
} finally res = saved
}

def apply(tp: Type): Unit = tp match {
case _ if seen.containsKey(tp) =>

case tr: TypeRef =>
val preInvalid = invalidateImpl(tr.pre)
var argsInvalid = false
tr.args.foreach(arg => argsInvalid = invalidateImpl(arg) || argsInvalid)
if (preInvalid || argsInvalid || changedSymbols(tr.sym)) {
tr.invalidateTypeRefCaches()
res = true
}
seen.put(tp, res)

case ct: CompoundType if ct.baseClasses.exists(changedSymbols) =>
ct.invalidatedCompoundTypeCaches()
res = true
seen.put(tp, res)

def invalidateCaches(t: Type, updatedSyms: List[Symbol]): Unit =
t match {
case tr: TypeRef if updatedSyms.contains(tr.sym) => tr.invalidateTypeRefCaches()
case ct: CompoundType if ct.baseClasses.exists(updatedSyms.contains) => ct.invalidatedCompoundTypeCaches()
case st: SingleType =>
if (updatedSyms.contains(st.sym)) st.invalidateSingleTypeCaches()
val underlying = st.underlying
if (underlying ne st)
invalidateCaches(underlying, updatedSyms)
val preInvalid = invalidateImpl(st.pre)
if (preInvalid || changedSymbols(st.sym)) {
st.invalidateSingleTypeCaches()
res = true
}
val underInvalid = (st.underlying ne st) && invalidateImpl(st.underlying)
res ||= underInvalid
seen.put(tp, res)

case _ =>
tp.foldOver(this)
seen.put(tp, res)
}

}

val shorthands = Set(
"scala.collection.immutable.List",
Expand Down

0 comments on commit e4fca6e

Please sign in to comment.