Skip to content

Commit

Permalink
Mirror infrastructure for generic sum types
Browse files Browse the repository at this point in the history
  • Loading branch information
odersky authored and milessabin committed May 30, 2019
1 parent c6c0f71 commit 5c924ca
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 34 deletions.
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Expand Up @@ -697,6 +697,9 @@ class Definitions {
lazy val Mirror_Product_fromProductR: TermRef = Mirror_ProductClass.requiredMethodRef(nme.fromProduct)
def Mirror_Product_fromProduct(implicit ctx: Context): Symbol = Mirror_Product_fromProductR.symbol

lazy val Mirror_SumType: TypeRef = ctx.requiredClassRef("scala.deriving.Mirror.Sum")
def Mirror_SumClass(implicit ctx: Context): ClassSymbol = Mirror_SumType.symbol.asClass

lazy val Mirror_SingletonType: TypeRef = ctx.requiredClassRef("scala.deriving.Mirror.Singleton")
def Mirror_SingletonClass(implicit ctx: Context): ClassSymbol = Mirror_SingletonType.symbol.asClass

Expand Down
38 changes: 34 additions & 4 deletions compiler/src/dotty/tools/dotc/transform/SymUtils.scala
Expand Up @@ -68,10 +68,36 @@ class SymUtils(val self: Symbol) extends AnyVal {
* Excluded are value classes, abstract classes and case classes with more than one
* parameter section.
*/
def isGenericProduct(implicit ctx: Context): Boolean =
self.is(CaseClass, butNot = Abstract) &&
self.primaryConstructor.info.paramInfoss.length == 1 &&
!isDerivedValueClass(self)
def whyNotGenericProduct(implicit ctx: Context): String =
if (!self.is(CaseClass)) "it is not a case class"
else if (self.is(Abstract)) "it is an abstract class"
else if (self.primaryConstructor.info.paramInfoss.length != 1) "it takes more than one parameter list"
else if (isDerivedValueClass(self)) "it is a value class"
else ""

def isGenericProduct(implicit ctx: Context): Boolean = whyNotGenericProduct.isEmpty

/** Is this a sealed class or trait for which a sum mirror is generated?
* Excluded are
*/
def whyNotGenericSum(implicit ctx: Context): String =
if (!self.is(Sealed))
s"it is not a sealed ${if (self.is(Trait)) "trait" else "class"}"
else {
val children = self.children
def problem(child: Symbol) =
if (child == self) "it has anonymous or inaccessible subclasses"
else if (!child.isClass) ""
else {
val s = child.whyNotGenericProduct
if (s.isEmpty) s
else "its child $child is not a generic product because $s"
}
if (children.isEmpty) "it does not have subclasses"
else children.filter(_.isClass).map(problem).find(!_.isEmpty).getOrElse("")
}

def isGenericSum(implicit ctx: Context): Boolean = whyNotGenericSum.isEmpty

/** If this is a constructor, its owner: otherwise this. */
final def skipConstructor(implicit ctx: Context): Symbol =
Expand Down Expand Up @@ -161,6 +187,10 @@ class SymUtils(val self: Symbol) extends AnyVal {
else owner.isLocal
}

/** The typeRef with wildcard arguments for each type parameter */
def rawTypeRef(implicit ctx: Context) =
self.typeRef.appliedTo(self.typeParams.map(_ => TypeBounds.empty))

/** Is symbol a quote operation? */
def isQuote(implicit ctx: Context): Boolean =
self == defn.InternalQuoted_exprQuote || self == defn.InternalQuoted_typeQuote
Expand Down
96 changes: 67 additions & 29 deletions compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala
Expand Up @@ -54,9 +54,10 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
def caseSymbols(implicit ctx: Context): List[Symbol] = { initSymbols; myCaseSymbols }
def caseModuleSymbols(implicit ctx: Context): List[Symbol] = { initSymbols; myCaseModuleSymbols }

private def alreadyDefined(sym: Symbol, clazz: ClassSymbol)(implicit ctx: Context): Boolean = {
private def existingDef(sym: Symbol, clazz: ClassSymbol)(implicit ctx: Context): Symbol = {
val existing = sym.matchingMember(clazz.thisType)
existing.exists && !(existing == sym || existing.is(Deferred))
if (existing != sym && !existing.is(Deferred)) existing
else NoSymbol
}

private def synthesizeDef(sym: TermSymbol, rhsFn: List[List[Tree]] => Context => Tree)(implicit ctx: Context): Tree =
Expand All @@ -80,7 +81,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
else Nil

def syntheticDefIfMissing(sym: Symbol): List[Tree] =
if (alreadyDefined(sym, clazz)) Nil else syntheticDef(sym) :: Nil
if (existingDef(sym, clazz).exists) Nil else syntheticDef(sym) :: Nil

def syntheticDef(sym: Symbol): Tree = {
val synthetic = sym.copy(
Expand Down Expand Up @@ -344,42 +345,79 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
}
}

/** For an enum T:
*
* def ordinal(x: MonoType) = x.enumTag
*
* For sealed trait with children of normalized types C_1, ..., C_n:
*
* def ordinal(x: MonoType) = x match {
* case _: C_1 => 0
* ...
* case _: C_n => n - 1
*
* Here, the normalized type of a class C is C[_, ...., _] with
* a wildcard for each type parameter. The normalized type of an object
* O is O.type.
*/
def ordinalBody(cls: Symbol, param: Tree)(implicit ctx: Context): Tree =
if (cls.is(Enum)) param.select(nme.enumTag)
else {
val cases =
for ((child, idx) <- cls.children.zipWithIndex) yield {
val patType = if (child.isTerm) child.termRef else child.rawTypeRef
val pat = Typed(untpd.Ident(nme.WILDCARD).withType(patType), TypeTree(patType))
CaseDef(pat, EmptyTree, Literal(Constant(idx)))
}
Match(param, cases)
}

/** - If `impl` is the companion of a generic sum, add `deriving.Mirror.Sum` parent
* and `MonoType` and `ordinal` members.
* - If `impl` is the companion of a generic product, add `deriving.Mirror.Product` parent
* and `MonoType` and `fromProduct` members.
*/
def addMirrorSupport(impl: Template)(implicit ctx: Context): Template = {
val clazz = ctx.owner.asClass
var newBody = serializableObjectMethod(clazz) ::: caseAndValueMethods(clazz) ::: impl.body
val linked = clazz.linkedClass

var newBody = impl.body
var newParents = impl.parents
def addParent(parent: Type) = {
def addParent(parent: Type): Unit = {
newParents = newParents :+ TypeTree(parent)
val oldClassInfo = clazz.classInfo
val newClassInfo = oldClassInfo.derivedClassInfo(
classParents = oldClassInfo.classParents :+ parent)
clazz.copySymDenotation(info = newClassInfo).installAfter(thisPhase)
}
def addMethod(name: TermName, info: Type, body: (Symbol, Tree, Context) => Tree): Unit = {
val meth = ctx.newSymbol(clazz, name, Synthetic | Method, info, coord = clazz.coord)
if (!existingDef(meth, clazz).exists) {
meth.entered
newBody = newBody :+
synthesizeDef(meth, vrefss => ctx => body(linked, vrefss.head.head, ctx))
}
}
lazy val monoType = {
val monoType =
ctx.newSymbol(clazz, tpnme.MonoType, Synthetic, TypeAlias(linked.rawTypeRef), coord = clazz.coord)
existingDef(monoType, clazz).orElse {
newBody = newBody :+ TypeDef(monoType).withSpan(ctx.owner.span.focus)
monoType.entered
}
}
if (clazz.is(Module)) {
if (clazz.is(Case)) addParent(defn.Mirror_SingletonType)
else {
val linked = clazz.linkedClass
if (linked.isGenericProduct) {
addParent(defn.Mirror_ProductType)
val rawClassType =
linked.typeRef.appliedTo(linked.typeParams.map(_ => TypeBounds.empty))
val monoType =
ctx.newSymbol(clazz, tpnme.MonoType, Synthetic, TypeAlias(rawClassType), coord = clazz.coord)
if (!alreadyDefined(monoType, clazz)) {
monoType.entered
newBody = newBody :+ TypeDef(monoType).withSpan(ctx.owner.span.focus)
}
val fromProduct =
ctx.newSymbol(clazz, nme.fromProduct, Synthetic | Method,
info = MethodType(defn.ProductType :: Nil, monoType.typeRef), coord = clazz.coord)
if (!alreadyDefined(fromProduct, clazz)) {
fromProduct.entered
newBody = newBody :+
synthesizeDef(fromProduct, vrefss => ctx =>
fromProductBody(linked, vrefss.head.head)(ctx)
.ensureConforms(rawClassType)) // t4758.scala or i3381.scala are examples where a cast is needed
}
}
if (clazz.is(Case))
addParent(defn.Mirror_SingletonType)
else if (linked.isGenericProduct) {
addParent(defn.Mirror_ProductType)
addMethod(nme.fromProduct, MethodType(defn.ProductType :: Nil, monoType.typeRef),
fromProductBody(_, _)(_).ensureConforms(monoType.typeRef)) // t4758.scala or i3381.scala are examples where a cast is needed
}
else if (linked.isGenericSum) {
addParent(defn.Mirror_SumType)
addMethod(nme.ordinal, MethodType(monoType.typeRef :: Nil, defn.IntType),
ordinalBody(_, _)(_))
}
}

Expand Down
4 changes: 3 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Namer.scala
Expand Up @@ -868,8 +868,10 @@ class Namer { typer: Typer =>
val child = if (denot.is(Module)) denot.sourceModule else denot.symbol
register(child, parent)
}
else if (denot.is(CaseVal, butNot = Method | Module))
else if (denot.is(CaseVal, butNot = Method | Module)) {
assert(denot.is(Enum), denot)
register(denot.symbol, denot.info)
}
}

/** Intentionally left without `implicit ctx` parameter. We need
Expand Down

0 comments on commit 5c924ca

Please sign in to comment.