Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify Enums #4003

Merged
merged 17 commits into from
Feb 24, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 75 additions & 46 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,15 @@ object desugar {
// ----- DerivedTypeTrees -----------------------------------

class SetterParamTree extends DerivedTypeTree {
def derivedType(sym: Symbol)(implicit ctx: Context) = sym.info.resultType
def derivedTree(sym: Symbol)(implicit ctx: Context) = tpd.TypeTree(sym.info.resultType)
}

class TypeRefTree extends DerivedTypeTree {
def derivedType(sym: Symbol)(implicit ctx: Context) = sym.typeRef
def derivedTree(sym: Symbol)(implicit ctx: Context) = tpd.TypeTree(sym.typeRef)
}

class TermRefTree extends DerivedTypeTree {
def derivedTree(sym: Symbol)(implicit ctx: Context) = tpd.ref(sym)
}

/** A type tree that computes its type from an existing parameter.
Expand Down Expand Up @@ -73,7 +77,7 @@ object desugar {
*
* parameter name == reference name ++ suffix
*/
def derivedType(sym: Symbol)(implicit ctx: Context) = {
def derivedTree(sym: Symbol)(implicit ctx: Context) = {
val relocate = new TypeMap {
val originalOwner = sym.owner
def apply(tp: Type) = tp match {
Expand All @@ -91,7 +95,7 @@ object desugar {
mapOver(tp)
}
}
relocate(sym.info)
tpd.TypeTree(relocate(sym.info))
}
}

Expand Down Expand Up @@ -301,34 +305,56 @@ object desugar {
val isCaseObject = mods.is(Case) && mods.is(Module)
val isImplicit = mods.is(Implicit)
val isEnum = mods.hasMod[Mod.Enum] && !mods.is(Module)
val isEnumCase = isLegalEnumCase(cdef)
val isEnumCase = mods.hasMod[Mod.EnumCase]
val isValueClass = parents.nonEmpty && isAnyVal(parents.head)
// This is not watertight, but `extends AnyVal` will be replaced by `inline` later.

// This is not watertight, but `extends AnyVal` will be replaced by `inline` later.

val originalTparams = constr1.tparams
val originalVparamss = constr1.vparamss
val constrTparams = originalTparams.map(toDefParam)
lazy val derivedEnumParams = enumClass.typeParams.map(derivedTypeParam)
val impliedTparams =
if (isEnumCase && originalTparams.isEmpty)
derivedEnumParams.map(tdef => tdef.withFlags(tdef.mods.flags | PrivateLocal))
else
originalTparams
val constrTparams = impliedTparams.map(toDefParam)
val constrVparamss =
if (originalVparamss.isEmpty) { // ensure parameter list is non-empty
if (isCaseClass) ctx.error(CaseClassMissingParamList(cdef), cdef.namePos)
if (isCaseClass && originalTparams.isEmpty)
ctx.error(CaseClassMissingParamList(cdef), cdef.namePos)
ListOfNil
}
else originalVparamss.nestedMap(toDefParam)
val constr = cpy.DefDef(constr1)(tparams = constrTparams, vparamss = constrVparamss)

// Add constructor type parameters and evidence implicit parameters
// to auxiliary constructors
val normalizedBody = impl.body map {
case ddef: DefDef if ddef.name.isConstructorName =>
decompose(
defDef(
addEvidenceParams(
cpy.DefDef(ddef)(tparams = constrTparams),
evidenceParams(constr1).map(toDefParam))))
case stat =>
stat
val (normalizedBody, enumCases, enumCompanionRef) = {
// Add constructor type parameters and evidence implicit parameters
// to auxiliary constructors; set defaultGetters as a side effect.
def expandConstructor(tree: Tree) = tree match {
case ddef: DefDef if ddef.name.isConstructorName =>
decompose(
defDef(
addEvidenceParams(
cpy.DefDef(ddef)(tparams = constrTparams),
evidenceParams(constr1).map(toDefParam))))
case stat =>
stat
}
// The Identifiers defined by a case
def caseIds(tree: Tree) = tree match {
case tree: MemberDef => Ident(tree.name.toTermName) :: Nil
case PatDef(_, ids, _, _) => ids
}
val stats = impl.body.map(expandConstructor)
if (isEnum) {
val (enumCases, enumStats) = stats.partition(DesugarEnums.isEnumCase)
val enumCompanionRef = new TermRefTree()
val enumImport = Import(enumCompanionRef, enumCases.flatMap(caseIds))
(enumImport :: enumStats, enumCases, enumCompanionRef)
}
else (stats, Nil, EmptyTree)
}

def anyRef = ref(defn.AnyRefAlias.typeRef)

val derivedTparams = constrTparams.map(derivedTypeParam(_))
Expand Down Expand Up @@ -361,20 +387,16 @@ object desugar {
val classTypeRef = appliedRef(classTycon)

// a reference to `enumClass`, with type parameters coming from the case constructor
lazy val enumClassTypeRef = enumClass.primaryConstructor.info match {
case info: PolyType =>
if (constrTparams.isEmpty)
interpolatedEnumParent(cdef.pos.startPos)
else if ((constrTparams.corresponds(info.paramNames))((param, name) => param.name == name))
appliedRef(enumClassRef)
else {
ctx.error(i"explicit extends clause needed because type parameters of case and enum class differ"
, cdef.pos.startPos)
appliedTypeTree(enumClassRef, constrTparams map (_ => anyRef))
}
case _ =>
lazy val enumClassTypeRef =
if (enumClass.typeParams.isEmpty)
enumClassRef
}
else if (originalTparams.isEmpty)
appliedRef(enumClassRef)
else {
ctx.error(i"explicit extends clause needed because both enum case and enum class have type parameters"
, cdef.pos.startPos)
appliedTypeTree(enumClassRef, constrTparams map (_ => anyRef))
}

// new C[Ts](paramss)
lazy val creatorExpr = New(classTypeRef, constrVparamss nestedMap refOfDef)
Expand Down Expand Up @@ -428,6 +450,7 @@ object desugar {
}

// Case classes and case objects get Product parents
// Enum cases get an inferred parent if no parents are given
var parents1 = parents
if (isEnumCase && parents.isEmpty)
parents1 = enumClassTypeRef :: Nil
Expand Down Expand Up @@ -473,7 +496,7 @@ object desugar {
.withMods(companionMods | Synthetic))
.withPos(cdef.pos).toList

val companionMeths = defaultGetters ::: eqInstances
val companionMembers = defaultGetters ::: eqInstances ::: enumCases

// The companion object definitions, if a companion is needed, Nil otherwise.
// companion definitions include:
Expand All @@ -486,18 +509,17 @@ object desugar {
// For all other classes, the parent is AnyRef.
val companions =
if (isCaseClass) {
// The return type of the `apply` method
// The return type of the `apply` method, and an (empty or singleton) list
// of widening coercions
val (applyResultTpt, widenDefs) =
if (!isEnumCase)
(TypeTree(), Nil)
else if (parents.isEmpty || enumClass.typeParams.isEmpty)
(enumClassTypeRef, Nil)
else {
val tparams = enumClass.typeParams.map(derivedTypeParam)
enumApplyResult(cdef, parents, tparams, appliedRef(enumClassRef, tparams))
}
else
enumApplyResult(cdef, parents, derivedEnumParams, appliedRef(enumClassRef, derivedEnumParams))

val parent =
val companionParent =
if (constrTparams.nonEmpty ||
constrVparamss.length > 1 ||
mods.is(Abstract) ||
Expand All @@ -519,10 +541,10 @@ object desugar {
DefDef(nme.unapply, derivedTparams, (unapplyParam :: Nil) :: Nil, TypeTree(), unapplyRHS)
.withMods(synthetic)
}
companionDefs(parent, applyMeths ::: unapplyMeth :: companionMeths)
companionDefs(companionParent, applyMeths ::: unapplyMeth :: companionMembers)
}
else if (companionMeths.nonEmpty)
companionDefs(anyRef, companionMeths)
else if (companionMembers.nonEmpty)
companionDefs(anyRef, companionMembers)
else if (isValueClass) {
constr0.vparamss match {
case (_ :: Nil) :: _ => companionDefs(anyRef, Nil)
Expand All @@ -531,6 +553,13 @@ object desugar {
}
else Nil

enumCompanionRef match {
case ref: TermRefTree => // have the enum import watch the companion object
val (modVal: ValDef) :: _ = companions
ref.watching(modVal)
case _ =>
}

// For an implicit class C[Ts](p11: T11, ..., p1N: T1N) ... (pM1: TM1, .., pMN: TMN), the method
// synthetic implicit C[Ts](p11: T11, ..., p1N: T1N) ... (pM1: TM1, ..., pMN: TMN): C[Ts] =
// new C[Ts](p11, ..., p1N) ... (pM1, ..., pMN) =
Expand Down Expand Up @@ -563,7 +592,7 @@ object desugar {
}

val cdef1 = addEnumFlags {
val originalTparamsIt = originalTparams.toIterator
val originalTparamsIt = impliedTparams.toIterator
val originalVparamsIt = originalVparamss.toIterator.flatten
val tparamAccessors = derivedTparams.map(_.withMods(originalTparamsIt.next().mods))
val caseAccessor = if (isCaseClass) CaseAccessor else EmptyFlags
Expand Down Expand Up @@ -603,7 +632,7 @@ object desugar {
val moduleName = checkNotReservedName(mdef).asTermName
val impl = mdef.impl
val mods = mdef.mods
lazy val isEnumCase = isLegalEnumCase(mdef)
lazy val isEnumCase = mods.hasMod[Mod.EnumCase]
if (mods is Package)
PackageDef(Ident(moduleName), cpy.ModuleDef(mdef)(nme.PACKAGE, impl).withMods(mods &~ Package) :: Nil)
else if (isEnumCase)
Expand Down Expand Up @@ -650,7 +679,7 @@ object desugar {
*/
def patDef(pdef: PatDef)(implicit ctx: Context): Tree = flatTree {
val PatDef(mods, pats, tpt, rhs) = pdef
if (mods.hasMod[Mod.EnumCase] && enumCaseIsLegal(pdef))
if (mods.hasMod[Mod.EnumCase])
pats map {
case id: Ident =>
expandSimpleEnumCase(id.name.asTermName, mods,
Expand Down
46 changes: 21 additions & 25 deletions compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import core._
import util.Positions._, Types._, Contexts._, Constants._, Names._, NameOps._, Flags._
import SymDenotations._, Symbols._, StdNames._, Annotations._, Trees._
import Decorators._
import reporting.diagnostic.messages.EnumCaseDefinitionInNonEnumOwner
import collection.mutable.ListBuffer
import util.Property
import typer.ErrorReporting._
Expand All @@ -23,20 +22,21 @@ object DesugarEnums {
/** Attachment containing the number of enum cases and the smallest kind that was seen so far. */
val EnumCaseCount = new Property.Key[(Int, CaseKind.Value)]

/** the enumeration class that is a companion of the current object */
def enumClass(implicit ctx: Context) = ctx.owner.linkedClass

/** Is this an enum case that's situated in a companion object of an enum class? */
def isLegalEnumCase(tree: MemberDef)(implicit ctx: Context): Boolean =
tree.mods.hasMod[Mod.EnumCase] && enumCaseIsLegal(tree)
/** The enumeration class that belongs to an enum case. This works no matter
* whether the case is still in the enum class or it has been transferred to the
* companion object.
*/
def enumClass(implicit ctx: Context): Symbol = {
val cls = ctx.owner
if (cls.is(Module)) cls.linkedClass else cls
}

/** Is enum case `tree` situated in a companion object of an enum class? */
def enumCaseIsLegal(tree: Tree)(implicit ctx: Context): Boolean = (
ctx.owner.is(ModuleClass) && enumClass.derivesFrom(defn.EnumClass)
|| { ctx.error(EnumCaseDefinitionInNonEnumOwner(ctx.owner), tree.pos)
false
}
)
/** Is `tree` an (untyped) enum case? */
def isEnumCase(tree: Tree)(implicit ctx: Context): Boolean = tree match {
case tree: MemberDef => tree.mods.hasMod[Mod.EnumCase]
case PatDef(mods, _, _, _) => mods.hasMod[Mod.EnumCase]
case _ => false
}

/** A reference to the enum class `E`, possibly followed by type arguments.
* Each covariant type parameter is approximated by its lower bound.
Expand Down Expand Up @@ -68,8 +68,8 @@ object DesugarEnums {

/** Add implied flags to an enum class or an enum case */
def addEnumFlags(cdef: TypeDef)(implicit ctx: Context) =
if (cdef.mods.hasMod[Mod.Enum]) cdef.withFlags(cdef.mods.flags | Abstract | Sealed)
else if (isLegalEnumCase(cdef)) cdef.withFlags(cdef.mods.flags | Final)
if (cdef.mods.hasMod[Mod.Enum]) cdef.withMods(cdef.mods.withFlags(cdef.mods.flags | Abstract | Sealed))
else if (isEnumCase(cdef)) cdef.withMods(cdef.mods.withFlags(cdef.mods.flags | Final))
else cdef

private def valuesDot(name: String) = Select(Ident(nme.DOLLAR_VALUES), name.toTermName)
Expand Down Expand Up @@ -193,24 +193,20 @@ object DesugarEnums {
}

/** Expand a module definition representing a parameterless enum case */
def expandEnumModule(name: TermName, impl: Template, mods: Modifiers, pos: Position)(implicit ctx: Context): Tree =
def expandEnumModule(name: TermName, impl: Template, mods: Modifiers, pos: Position)(implicit ctx: Context): Tree = {
assert(impl.body.isEmpty)
if (impl.parents.isEmpty)
if (impl.body.isEmpty)
expandSimpleEnumCase(name, mods, pos)
else {
val parent = interpolatedEnumParent(pos)
expandEnumModule(name, cpy.Template(impl)(parents = parent :: Nil), mods, pos)
}
expandSimpleEnumCase(name, mods, pos)
else {
def toStringMeth =
DefDef(nme.toString_, Nil, Nil, TypeTree(defn.StringType), Literal(Constant(name.toString)))
.withFlags(Override)
val (tagMeth, scaffolding) = enumTagMeth(CaseKind.Object)
val impl1 = cpy.Template(impl)(body =
impl.body ++ List(tagMeth, toStringMeth) ++ registerCall)
val impl1 = cpy.Template(impl)(body = List(tagMeth, toStringMeth) ++ registerCall)
val vdef = ValDef(name, TypeTree(), New(impl1)).withMods(mods | Final)
flatTree(scaffolding ::: vdef :: Nil).withPos(pos)
}
}

/** Expand a simple enum case */
def expandSimpleEnumCase(name: TermName, mods: Modifiers, pos: Position)(implicit ctx: Context): Tree =
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
case NoPrefix =>
true
case pre: ThisType =>
tp.isType ||
pre.cls.isStaticOwner ||
tp.symbol.isParamOrAccessor && !pre.cls.is(Trait) && ctx.owner.enclosingClass == pre.cls
// was ctx.owner.enclosingClass.derivesFrom(pre.cls) which was not tight enough
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
*/
def ensureCompletions(implicit ctx: Context): Unit = ()

/** The method that computes the type of this tree */
def derivedType(originalSym: Symbol)(implicit ctx: Context): Type
/** The method that computes the tree with the derived type */
def derivedTree(originalSym: Symbol)(implicit ctx: Context): tpd.Tree
}

/** Property key containing TypeTrees whose type is computed
Expand Down
Loading