Skip to content

Commit

Permalink
Implement pickling/unpickling for dependent refinements
Browse files Browse the repository at this point in the history
Needs an addition to Tasty format: TRACKED as a modifier.
  • Loading branch information
odersky committed Nov 21, 2023
1 parent a4656b1 commit 30c78c1
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 47 deletions.
42 changes: 42 additions & 0 deletions compiler/src/dotty/tools/dotc/core/NamerOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package core
import Contexts.*, Symbols.*, Types.*, Flags.*, Scopes.*, Decorators.*, Names.*, NameOps.*
import SymDenotations.{LazyType, SymDenotation}, StdNames.nme
import TypeApplications.EtaExpansion
import collection.mutable

/** Operations that are shared between Namer and TreeUnpickler */
object NamerOps:
Expand All @@ -24,6 +25,47 @@ object NamerOps:
resType = RefinedType(resType, param.name, param.termRef)
resType

/** Split dependent class refinements off parent type and add them to `refinements` */
extension (tp: Type)
def separateRefinements(refinements: mutable.LinkedHashMap[Name, Type])(using Context): Type =
tp match
case RefinedType(tp1, rname, rinfo) =>
try tp1.separateRefinements(refinements)
finally
refinements(rname) = refinements.get(rname) match
case Some(tp) => tp & rinfo
case None => rinfo
case tp => tp

/** Add all parent `refinements` to the result type of the info of the dependent
* class constructor `constr`. Parent refinements refer to parameter accessors
* in the current class. These have to be mapped to the paramRefs of the
* constructor info.
*/
def integrateParentRefinements(
constr: Symbol, refinements: mutable.LinkedHashMap[Name, Type])(using Context): Unit =

/** @param info the (remaining part) of the constructor info
* @param nameToParamRef the map from parameter names to paramRefs of
* previously encountered parts of `info`.
*/
def recur(info: Type, nameToParamRef: mutable.Map[Name, Type]): Type = info match
case info: MethodOrPoly =>
info.derivedLambdaType(resType =
recur(info.resType, nameToParamRef ++= info.paramNames.zip(info.paramRefs)))
case _ =>
val mapParams = new TypeMap:
def apply(t: Type) = t match
case t: TermRef if t.symbol.is(ParamAccessor) && t.symbol.owner == constr.owner =>
nameToParamRef(t.name)
case _ =>
mapOver(t)
refinements.foldLeft(info): (info, refinement) =>
val (rname, rinfo) = refinement
RefinedType(info, rname, mapParams(rinfo))
constr.info = recur(constr.info, mutable.Map())
end integrateParentRefinements

/** If isConstructor, make sure it has at least one non-implicit parameter list
* This is done by adding a () in front of a leading old style implicit parameter,
* or by adding a () as last -- or only -- parameter list if the constructor has
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,7 @@ class TreePickler(pickler: TastyPickler) {
if (flags.is(Exported)) writeModTag(EXPORTED)
if (flags.is(Given)) writeModTag(GIVEN)
if (flags.is(Implicit)) writeModTag(IMPLICIT)
if (flags.is(Tracked)) writeModTag(TRACKED)
if (isTerm) {
if (flags.is(Lazy, butNot = Module)) writeModTag(LAZY)
if (flags.is(AbsOverride)) { writeModTag(ABSTRACT); writeModTag(OVERRIDE) }
Expand All @@ -787,7 +788,6 @@ class TreePickler(pickler: TastyPickler) {
if (flags.is(Extension)) writeModTag(EXTENSION)
if (flags.is(ParamAccessor)) writeModTag(PARAMsetter)
if (flags.is(SuperParamAlias)) writeModTag(PARAMalias)
if (flags.is(Tracked)) writeModTag(TRACKED)
assert(!(flags.is(Label)))
}
else {
Expand Down
27 changes: 19 additions & 8 deletions compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1012,12 +1012,20 @@ class TreeUnpickler(reader: TastyReader,
* but skip constructor arguments. Return any trees that were partially
* parsed in this way as InferredTypeTrees.
*/
def readParents(withArgs: Boolean)(using Context): List[Tree] =
def readParents(cls: ClassSymbol, withArgs: Boolean)(using Context): List[Tree] =
collectWhile(nextByte != SELFDEF && nextByte != DEFDEF) {
nextUnsharedTag match
case APPLY | TYPEAPPLY | BLOCK =>
if withArgs then readTree()
else InferredTypeTree().withType(readParentType())
if withArgs then
readTree()
else if cls.is(Dependent) then
val parentReader = fork
val parentCoreType = readParentType()
if parentCoreType.dealias.typeSymbol.is(Dependent)
then parentReader.readTree() // read the whole tree since we need to see the refinement
else InferredTypeTree().withType(parentCoreType)
else
InferredTypeTree().withType(readParentType())
case _ => readTpt()
}

Expand All @@ -1043,9 +1051,10 @@ class TreeUnpickler(reader: TastyReader,
while (bodyIndexer.reader.nextByte != DEFDEF) bodyIndexer.skipTree()
bodyIndexer.indexStats(end)
}
val parentReader = fork
val parents = readParents(withArgs = false)(using parentCtx)
val parentTypes = parents.map(_.tpe.dealias)
val parentsReader = fork
val parents = readParents(cls, withArgs = false)(using parentCtx)
val parentRefinements = mutable.LinkedHashMap[Name, Type]()
val parentTypes = parents.map(_.tpe.dealias.separateRefinements(parentRefinements))
val self =
if (nextByte == SELFDEF) {
readByte()
Expand All @@ -1058,11 +1067,13 @@ class TreeUnpickler(reader: TastyReader,
selfInfo = if (self.isEmpty) NoType else self.tpt.tpe
).integrateOpaqueMembers
val constr = readIndexedDef().asInstanceOf[DefDef]
if parentRefinements.nonEmpty then
integrateParentRefinements(constr.symbol, parentRefinements)
val mappedParents: LazyTreeList =
if parents.exists(_.isInstanceOf[InferredTypeTree]) then
// parents were not read fully, will need to be read again later on demand
new LazyReader(parentReader, localDummy, ctx.mode, ctx.source,
_.readParents(withArgs = true)
new LazyReader(parentsReader, localDummy, ctx.mode, ctx.source,
_.readParents(cls, withArgs = true)
.map(_.changeOwner(localDummy, constr.symbol)))
else parents

Expand Down
40 changes: 2 additions & 38 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1504,7 +1504,6 @@ class Namer { typer: Typer =>
if ptype.typeParams.isEmpty
//&& !ptype.dealias.typeSymbol.primaryConstructor.info.finalResultType.isInstanceOf[RefinedType]
&& !ptype.dealias.typeSymbol.is(Dependent)
|| ctx.erasedTypes
then
ptype
else
Expand Down Expand Up @@ -1613,46 +1612,12 @@ class Namer { typer: Typer =>
/** The refinements coming from all parent class constructor applications */
val parentRefinements = mutable.LinkedHashMap[Name, Type]()

/** Split refinements off parent type and add them to `parentRefinements` */
def separateRefinements(tp: Type): Type = tp match
case RefinedType(tp1, rname, rinfo) =>
try separateRefinements(tp1)
finally
parentRefinements(rname) = parentRefinements.get(rname) match
case Some(tp) => tp & rinfo
case None => rinfo
case tp => tp

/** Add all parent refinements to the result type of the `info` of
* the class constructor. Parent refinements refer to parameter accessors
* in the current class. These have to be mapped to the paramRefs of the
* constructor info.
* @param info The (remaining part) of the constructor info
* @param nameToParamRef The map from parameter names to paramRefs of
* previously encountered parts of `info`.
*/
def integrateParentRefinements(info: Type, nameToParamRef: Map[Name, Type]): Type = info match
case info: MethodOrPoly =>
info.derivedLambdaType(resType =
integrateParentRefinements(info.resType,
nameToParamRef ++ info.paramNames.zip(info.paramRefs)))
case _ =>
val mapParams = new TypeMap:
def apply(t: Type) = t match
case t: TermRef if t.symbol.is(ParamAccessor) && t.symbol.owner == cls =>
nameToParamRef(t.name)
case _ =>
mapOver(t)
parentRefinements.foldLeft(info): (info, refinement) =>
val (rname, rinfo) = refinement
RefinedType(info, rname, mapParams(rinfo))

val parentTypes =
defn.adjustForTuple(cls, cls.typeParams,
defn.adjustForBoxedUnit(cls,
addUsingTraits(
ensureFirstIsClass(cls, parents.map(checkedParentType(_)))
))).map(separateRefinements)
))).map(_.separateRefinements(parentRefinements))

typr.println(i"completing $denot, parents = $parents%, %, stripped parent types = $parentTypes%, %")
typr.println(i"constr type = ${cls.primaryConstructor.infoOrCompleter}, refinements = ${parentRefinements.toList}")
Expand All @@ -1671,8 +1636,7 @@ class Namer { typer: Typer =>
tempInfo = null // The temporary info can now be garbage-collected

if parentRefinements.nonEmpty then
val constr = cls.primaryConstructor
constr.info = integrateParentRefinements(constr.info, Map())
integrateParentRefinements(cls.primaryConstructor, parentRefinements)
cls.setFlag(Dependent)
Checking.checkWellFormed(cls)
if (isDerivedValueClass(cls)) cls.setFlag(Final)
Expand Down
1 change: 1 addition & 0 deletions tasty/src/dotty/tools/tasty/TastyFormat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ Standard-Section: "ASTs" TopLevelStat*
EXPORTED -- An export forwarder
OPEN -- an open class
INVISIBLE -- invisible during typechecking
TRACKED -- a tracked class parameter / a dependent class
Annotation
Variance = STABLE -- invariant
Expand Down
4 changes: 4 additions & 0 deletions tests/neg/i3964.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,7 @@ object Test1:
trait Foo { val x: Animal }
val foo: Foo { val x: Cat } = new Foo { val x = new Cat } // error, but should work

object Test3:
trait Vec(tracked val size: Int)
class Vec8 extends Vec(8):
val s: 8 = size // error, but should work

0 comments on commit 30c78c1

Please sign in to comment.