Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mini-phase to fix constructors for enums extending java.lang.Enum #6602

Merged
merged 13 commits into from
Jun 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ class Compiler {
List(new FirstTransform, // Some transformations to put trees into a canonical form
new CheckReentrant, // Internal use only: Check that compiled program has no data races involving global vars
new ElimPackagePrefixes, // Eliminate references to package prefixes in Select nodes
new CookComments) :: // Cook the comments: expand variables, doc, etc.
new CookComments, // Cook the comments: expand variables, doc, etc.
new CompleteJavaEnums) :: // Fill in constructors for Java enums
List(new CheckStatic, // Check restrictions that apply to @static members
new ElimRepeated, // Rewrite vararg parameters and arguments
new ExpandSAMs, // Expand single abstract method closures to anonymous classes
Expand Down
9 changes: 9 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,15 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
loop(tree, Nil, Nil)
}

/** Decompose a template body into parameters and other statements */
def decomposeTemplateBody(body: List[Tree])(implicit ctx: Context): (List[Tree], List[Tree]) =
body.partition {
case stat: TypeDef => stat.symbol is Flags.Param
case stat: ValOrDefDef =>
stat.symbol.is(Flags.ParamAccessor) && !stat.symbol.isSetter
case _ => false
}

/** An extractor for closures, either contained in a block or standalone.
*/
object closure {
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,6 @@ class Definitions {
val companion = JavaLangPackageVal.info.decl(nme.Object).symbol
companion.moduleClass.info = NoType // to indicate that it does not really exist
companion.info = NoType // to indicate that it does not really exist

completeClass(cls)
}
def ObjectType: TypeRef = ObjectClass.typeRef
Expand Down Expand Up @@ -674,6 +673,8 @@ class Definitions {
def NoneClass(implicit ctx: Context): ClassSymbol = NoneModuleRef.symbol.moduleClass.asClass
lazy val EnumType: TypeRef = ctx.requiredClassRef("scala.Enum")
def EnumClass(implicit ctx: Context): ClassSymbol = EnumType.symbol.asClass
lazy val JEnumType: TypeRef = ctx.requiredClassRef("scala.compat.JEnum")
def JEnumClass(implicit ctx: Context): ClassSymbol = JEnumType.symbol.asClass
lazy val EnumValuesType: TypeRef = ctx.requiredClassRef("scala.runtime.EnumValues")
def EnumValuesClass(implicit ctx: Context): ClassSymbol = EnumValuesType.symbol.asClass
lazy val ProductType: TypeRef = ctx.requiredClassRef("scala.Product")
Expand Down
7 changes: 5 additions & 2 deletions compiler/src/dotty/tools/dotc/core/Denotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -814,8 +814,11 @@ object Denotations {
def invalidateInheritedInfo(): Unit = ()

private def updateValidity()(implicit ctx: Context): this.type = {
assert(ctx.runId >= validFor.runId || ctx.settings.YtestPickler.value, // mixing test pickler with debug printing can travel back in time
s"denotation $this invalid in run ${ctx.runId}. ValidFor: $validFor")
assert(
ctx.runId >= validFor.runId ||
ctx.settings.YtestPickler.value || // mixing test pickler with debug printing can travel back in time
symbol.is(Permanent), // Permanent symbols are valid in all runIds
s"denotation $this invalid in run ${ctx.runId}. ValidFor: $validFor")
var d: SingleDenotation = this
do {
d.validFor = Period(ctx.period.runId, d.validFor.firstPhaseId, d.validFor.lastPhaseId)
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/tasty/TastyFormat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,7 @@ object TastyFormat {
| STATIC
| OBJECT
| TRAIT
| ENUM
| LOCAL
| SYNTHETIC
| ARTIFACT
Expand Down
81 changes: 40 additions & 41 deletions compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -509,12 +509,7 @@ class TreePickler(pickler: TastyPickler) {
case tree: Template =>
registerDef(tree.symbol)
writeByte(TEMPLATE)
val (params, rest) = tree.body partition {
case stat: TypeDef => stat.symbol is Flags.Param
case stat: ValOrDefDef =>
stat.symbol.is(Flags.ParamAccessor) && !stat.symbol.isSetter
case _ => false
}
val (params, rest) = decomposeTemplateBody(tree.body)
withLength {
pickleParams(params)
tree.parents.foreach(pickleTree)
Expand Down Expand Up @@ -635,44 +630,48 @@ class TreePickler(pickler: TastyPickler) {

def pickleFlags(flags: FlagSet, isTerm: Boolean)(implicit ctx: Context): Unit = {
import Flags._
if (flags is Private) writeByte(PRIVATE)
if (flags is Protected) writeByte(PROTECTED)
if (flags.is(Final, butNot = Module)) writeByte(FINAL)
if (flags is Case) writeByte(CASE)
if (flags is Override) writeByte(OVERRIDE)
if (flags is Inline) writeByte(INLINE)
if (flags is InlineProxy) writeByte(INLINEPROXY)
if (flags is Macro) writeByte(MACRO)
if (flags is JavaStatic) writeByte(STATIC)
if (flags is Module) writeByte(OBJECT)
if (flags is Enum) writeByte(ENUM)
if (flags is Local) writeByte(LOCAL)
if (flags is Synthetic) writeByte(SYNTHETIC)
if (flags is Artifact) writeByte(ARTIFACT)
if (flags is Scala2x) writeByte(SCALA2X)
def writeModTag(tag: Int) = {
assert(isModifierTag(tag))
writeByte(tag)
}
if (flags is Private) writeModTag(PRIVATE)
if (flags is Protected) writeModTag(PROTECTED)
if (flags.is(Final, butNot = Module)) writeModTag(FINAL)
if (flags is Case) writeModTag(CASE)
if (flags is Override) writeModTag(OVERRIDE)
if (flags is Inline) writeModTag(INLINE)
if (flags is InlineProxy) writeModTag(INLINEPROXY)
if (flags is Macro) writeModTag(MACRO)
if (flags is JavaStatic) writeModTag(STATIC)
if (flags is Module) writeModTag(OBJECT)
if (flags is Enum) writeModTag(ENUM)
if (flags is Local) writeModTag(LOCAL)
if (flags is Synthetic) writeModTag(SYNTHETIC)
if (flags is Artifact) writeModTag(ARTIFACT)
if (flags is Scala2x) writeModTag(SCALA2X)
if (isTerm) {
if (flags is Implicit) writeByte(IMPLICIT)
if (flags is Implied) writeByte(IMPLIED)
if (flags is Erased) writeByte(ERASED)
if (flags.is(Lazy, butNot = Module)) writeByte(LAZY)
if (flags is AbsOverride) { writeByte(ABSTRACT); writeByte(OVERRIDE) }
if (flags is Mutable) writeByte(MUTABLE)
if (flags is Accessor) writeByte(FIELDaccessor)
if (flags is CaseAccessor) writeByte(CASEaccessor)
if (flags is DefaultParameterized) writeByte(DEFAULTparameterized)
if (flags is StableRealizable) writeByte(STABLE)
if (flags is Extension) writeByte(EXTENSION)
if (flags is Given) writeByte(GIVEN)
if (flags is ParamAccessor) writeByte(PARAMsetter)
if (flags is Exported) writeByte(EXPORTED)
if (flags is Implicit) writeModTag(IMPLICIT)
if (flags is Implied) writeModTag(IMPLIED)
if (flags is Erased) writeModTag(ERASED)
if (flags.is(Lazy, butNot = Module)) writeModTag(LAZY)
if (flags is AbsOverride) { writeModTag(ABSTRACT); writeModTag(OVERRIDE) }
if (flags is Mutable) writeModTag(MUTABLE)
if (flags is Accessor) writeModTag(FIELDaccessor)
if (flags is CaseAccessor) writeModTag(CASEaccessor)
if (flags is DefaultParameterized) writeModTag(DEFAULTparameterized)
if (flags is StableRealizable) writeModTag(STABLE)
if (flags is Extension) writeModTag(EXTENSION)
if (flags is Given) writeModTag(GIVEN)
if (flags is ParamAccessor) writeModTag(PARAMsetter)
if (flags is Exported) writeModTag(EXPORTED)
assert(!(flags is Label))
} else {
if (flags is Sealed) writeByte(SEALED)
if (flags is Abstract) writeByte(ABSTRACT)
if (flags is Trait) writeByte(TRAIT)
if (flags is Covariant) writeByte(COVARIANT)
if (flags is Contravariant) writeByte(CONTRAVARIANT)
if (flags is Opaque) writeByte(OPAQUE)
if (flags is Sealed) writeModTag(SEALED)
if (flags is Abstract) writeModTag(ABSTRACT)
if (flags is Trait) writeModTag(TRAIT)
if (flags is Covariant) writeModTag(COVARIANT)
if (flags is Contravariant) writeModTag(CONTRAVARIANT)
if (flags is Opaque) writeModTag(OPAQUE)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
else if (sym.is(ModuleClass))
nameString(sym.name.stripModuleClassSuffix)
else if (hasMeaninglessName(sym))
simpleNameString(sym.owner)
simpleNameString(sym.owner) + idString(sym)
else
nameString(sym)
(keywordText(kindString(sym)) ~~ {
Expand Down
161 changes: 161 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/CompleteJavaEnums.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package dotty.tools.dotc
package transform

import core._
import Names._
import StdNames.{nme, tpnme}
import Types._
import dotty.tools.dotc.transform.MegaPhase._
import Flags._
import Contexts.Context
import Symbols._
import Constants._
import Decorators._
import DenotTransformers._

object CompleteJavaEnums {
val name: String = "completeJavaEnums"

private val nameParamName: TermName = "$name".toTermName
private val ordinalParamName: TermName = "$ordinal".toTermName
}

/** For Scala enums that inherit from java.lang.Enum:
* Add constructor parameters for `name` and `ordinal` to pass from each
* case to the java.lang.Enum class.
*/
class CompleteJavaEnums extends MiniPhase with InfoTransformer { thisPhase =>
import CompleteJavaEnums._
import ast.tpd._

override def phaseName: String = CompleteJavaEnums.name

override def relaxedTypingInGroup: Boolean = true
// Because it adds additional parameters to some constructors

def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context): Type =
if (sym.isConstructor && derivesFromJEnum(sym.owner)) addConstrParams(sym.info)
else tp

/** Is `sym` a Scala enum class that derives (directly) from `java.lang.Enum`?
*/
private def derivesFromJEnum(sym: Symbol)(implicit ctx: Context) =
sym.is(Enum, butNot = Case) &&
sym.info.parents.exists(p => p.typeSymbol == defn.JEnumClass)

/** Add constructor parameters `$name: String` and `$ordinal: Int` to the end of
* the last parameter list of (method- or poly-) type `tp`.
*/
private def addConstrParams(tp: Type)(implicit ctx: Context): Type = tp match {
case tp: PolyType =>
tp.derivedLambdaType(resType = addConstrParams(tp.resType))
case tp: MethodType =>
tp.resType match {
case restpe: MethodType =>
tp.derivedLambdaType(resType = addConstrParams(restpe))
case _ =>
tp.derivedLambdaType(
paramNames = tp.paramNames ++ List(nameParamName, ordinalParamName),
paramInfos = tp.paramInfos ++ List(defn.StringType, defn.IntType))
}
}

/** The list of parameter definitions `$name: String, $ordinal: Int`, in given `owner`
* with given flags (either `Param` or `ParamAccessor`)
*/
private def addedParams(owner: Symbol, flag: FlagSet)(implicit ctx: Context): List[ValDef] = {
val nameParam = ctx.newSymbol(owner, nameParamName, flag | Synthetic, defn.StringType, coord = owner.span)
val ordinalParam = ctx.newSymbol(owner, ordinalParamName, flag | Synthetic, defn.IntType, coord = owner.span)
List(ValDef(nameParam), ValDef(ordinalParam))
}

/** Add arguments `args` to the parent constructor application in `parents` that invokes
* a constructor of `targetCls`,
*/
private def addEnumConstrArgs(targetCls: Symbol, parents: List[Tree], args: List[Tree])(implicit ctx: Context): List[Tree] =
parents.map {
case app @ Apply(fn, args0) if fn.symbol.owner == targetCls => cpy.Apply(app)(fn, args0 ++ args)
case p => p
}

/** 1. If this is a constructor of a enum class that extends, add $name and $ordinal parameters to it.
*
* 2. If this is a $new method that creates simple cases, pass $name and $ordinal parameters
* to the enum superclass. The $new method looks like this:
*
* def $new(..., enumTag: Int, name: String) = {
* class $anon extends E(...) { ... }
* new $anon
* }
*
* After the transform it is expanded to
*
* def $new(..., enumTag: Int, name: String) = {
* class $anon extends E(..., name, enumTag) { ... }
* new $anon
* }
*/
override def transformDefDef(tree: DefDef)(implicit ctx: Context): DefDef = {
val sym = tree.symbol
if (sym.isConstructor && derivesFromJEnum(sym.owner))
cpy.DefDef(tree)(
vparamss = tree.vparamss.init :+ (tree.vparamss.last ++ addedParams(sym, Param)))
else if (sym.name == nme.DOLLAR_NEW && derivesFromJEnum(sym.owner.linkedClass)) {
val Block((tdef @ TypeDef(tpnme.ANON_CLASS, templ: Template)) :: Nil, call) = tree.rhs
val args = tree.vparamss.last.takeRight(2).map(param => ref(param.symbol)).reverse
val templ1 = cpy.Template(templ)(
parents = addEnumConstrArgs(sym.owner.linkedClass, templ.parents, args))
cpy.DefDef(tree)(
rhs = cpy.Block(tree.rhs)(cpy.TypeDef(tdef)(tdef.name, templ1) :: Nil, call))
}
else tree
}

/** 1. If this is an enum class, add $name and $ordinal parameters to its
* parameter accessors and pass them on to the java.lang.Enum constructor,
* replacing the dummy arguments that were passed before.
*
* 2. If this is an anonymous class that implement a value enum case,
* pass $name and $ordinal parameters to the enum superclass. The class
* looks like this:
*
* class $anon extends E(...) {
* ...
* def enumTag = N
* def toString = S
* ...
* }
*
* After the transform it is expanded to
*
* class $anon extends E(..., N, S) {
* "same as before"
* }
*/
override def transformTemplate(templ: Template)(implicit ctx: Context): Template = {
val cls = templ.symbol.owner
if (derivesFromJEnum(cls)) {
val (params, rest) = decomposeTemplateBody(templ.body)
val addedDefs = addedParams(cls, ParamAccessor)
val addedSyms = addedDefs.map(_.symbol.entered)
val parents1 = templ.parents.map {
case app @ Apply(fn, _) if fn.symbol.owner == defn.JEnumClass =>
cpy.Apply(app)(fn, addedSyms.map(ref))
case p => p
}
cpy.Template(templ)(
parents = parents1,
body = params ++ addedDefs ++ rest)
}
else if (cls.isAnonymousClass && cls.owner.is(EnumCase) && derivesFromJEnum(cls.owner.owner.linkedClass)) {
def rhsOf(name: TermName) =
templ.body.collect {
case mdef: DefDef if mdef.name == name => mdef.rhs
}.head
val args = List(rhsOf(nme.toString_), rhsOf(nme.enumTag))
cpy.Template(templ)(
parents = addEnumConstrArgs(cls.owner.owner.linkedClass, templ.parents, args))
}
else templ
}
}
22 changes: 14 additions & 8 deletions compiler/src/dotty/tools/dotc/typer/Checking.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1046,20 +1046,26 @@ trait Checking {
ctx.error(em"$what $msg", posd.sourcePos)
}

/** Check that all case classes that extend `scala.Enum` are `enum` cases */
/** 1. Check that all case classes that extend `scala.Enum` are `enum` cases
* 2. Check that case class `enum` cases do not extend java.lang.Enum.
*/
def checkEnum(cdef: untpd.TypeDef, cls: Symbol, firstParent: Symbol)(implicit ctx: Context): Unit = {
import untpd.modsDeco
def isEnumAnonCls =
cls.isAnonymousClass &&
cls.owner.isTerm &&
(cls.owner.flagsUNSAFE.is(Case) || cls.owner.name == nme.DOLLAR_NEW)
if (!cdef.mods.isEnumCase && !isEnumAnonCls) {
// Since enums are classes and Namer checks that classes don't extend multiple classes, we only check the class
// parent.
//
// Unlike firstParent.derivesFrom(defn.EnumClass), this test allows inheriting from `Enum` by hand;
// see enum-List-control.scala.
if (cls.is(Case) || firstParent.is(Enum))
if (!isEnumAnonCls) {
if (cdef.mods.isEnumCase) {
if (cls.derivesFrom(defn.JEnumClass))
ctx.error(em"parameterized case is not allowed in an enum that extends java.lang.Enum", cdef.sourcePos)
}
else if (cls.is(Case) || firstParent.is(Enum))
// Since enums are classes and Namer checks that classes don't extend multiple classes, we only check the class
// parent.
//
// Unlike firstParent.derivesFrom(defn.EnumClass), this test allows inheriting from `Enum` by hand;
// see enum-List-control.scala.
ctx.error(ClassCannotExtendEnum(cls, firstParent), cdef.sourcePos)
}
}
Expand Down
20 changes: 3 additions & 17 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -498,23 +498,9 @@ class Namer { typer: Typer =>
recur(expanded(origStat))
}

/** Determines whether this field holds an enum constant.
* To qualify, the following conditions must be met:
* - The field's class has the ENUM flag set
* - The field's class extends java.lang.Enum
* - The field has the ENUM flag set
* - The field is static
* - The field is stable
*/
def isEnumConstant(vd: ValDef)(implicit ctx: Context): Boolean = {
// val ownerHasEnumFlag =
// Necessary to check because scalac puts Java's static members into the companion object
// while Scala's enum constants live directly in the class.
// We don't check for clazz.superClass == JavaEnumClass, because this causes a illegal
// cyclic reference error. See the commit message for details.
// if (ctx.compilationUnit.isJava) ctx.owner.companionClass.is(Enum) else ctx.owner.is(Enum)
vd.mods.is(JavaEnumValue) // && ownerHasEnumFlag
}
/** Determines whether this field holds an enum constant. */
def isEnumConstant(vd: ValDef)(implicit ctx: Context): Boolean =
vd.mods.is(JavaEnumValue)

/** Add child annotation for `child` to annotations of `cls`. The annotation
* is added at the correct insertion point, so that Child annotations appear
Expand Down
Loading