Skip to content

Commit

Permalink
Generate MonoType and fromProduct for generic products
Browse files Browse the repository at this point in the history
Generate MonoType and fromProduct members for generic products.
  • Loading branch information
odersky committed May 19, 2019
1 parent 2f295c4 commit 8487127
Show file tree
Hide file tree
Showing 9 changed files with 159 additions and 52 deletions.
66 changes: 46 additions & 20 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import reporting.diagnostic.messages._
import reporting.trace
import annotation.constructorOnly
import printing.Formatting.hl
import config.Printers

import scala.annotation.internal.sharable

Expand Down Expand Up @@ -51,7 +52,7 @@ object desugar {
private type VarInfo = (NameTree, Tree)

/** Is `name` the name of a method that can be invalidated as a compiler-generated
* case class method that clashes with a user-defined method?
* case class method if it clashes with a user-defined method?
*/
def isRetractableCaseClassMethodName(name: Name)(implicit ctx: Context): Boolean = name match {
case nme.apply | nme.unapply | nme.unapplySeq | nme.copy => true
Expand Down Expand Up @@ -394,6 +395,10 @@ object desugar {
val isValueClass = parents.nonEmpty && isAnyVal(parents.head)
// This is not watertight, but `extends AnyVal` will be replaced by `inline` later.

/** The untyped analogue of SymUtils.isGenericProduct */
val isGenericProduct =
mods.is(Case, butNot = Abstract) && constr1.vparamss.length == 1 && !isValueClass

val originalTparams = constr1.tparams
val originalVparamss = constr1.vparamss
lazy val derivedEnumParams = enumClass.typeParams.map(derivedTypeParam)
Expand Down Expand Up @@ -585,11 +590,16 @@ object desugar {
else Nil
}

def mirrorMemberType(str: String) =
Select(Select(scalaDot("deriving".toTermName), "Mirror".toTermName), str.toTypeName)

var parents1 = parents
if (isEnumCase && parents.isEmpty)
parents1 = enumClassTypeRef :: Nil
if (isCaseClass | isCaseObject)
if (isCaseClass)
parents1 = parents1 :+ scalaDot(str.Product.toTypeName) :+ scalaDot(nme.Serializable.toTypeName)
else if (isCaseObject)
parents1 = parents1 :+ mirrorMemberType("Singleton") :+ scalaDot(nme.Serializable.toTypeName)
else if (isObject)
parents1 = parents1 :+ scalaDot(nme.Serializable.toTypeName)
if (isEnum)
Expand All @@ -600,11 +610,11 @@ object desugar {
if (mods.is(Module)) (impl.derived, Nil) else (Nil, impl.derived)

// The thicket which is the desugared version of the companion object
// synthetic object C extends parentTpt derives class-derived { defs }
def companionDefs(parentTpt: Tree, defs: List[Tree]) = {
// synthetic object C extends parentTpts derives class-derived { defs }
def companionDefs(parentTpts: List[Tree], defs: List[Tree]) = {
val mdefs = moduleDef(
ModuleDef(
className.toTermName, Template(emptyConstructor, parentTpt :: Nil, companionDerived, EmptyValDef, defs))
className.toTermName, Template(emptyConstructor, parentTpts, companionDerived, EmptyValDef, defs))
.withMods(companionMods | Synthetic))
.withSpan(cdef.span).toList
if (companionDerived.nonEmpty)
Expand All @@ -627,6 +637,7 @@ object desugar {
// For all other classes, the parent is AnyRef.
val companions =
if (isCaseClass) {

// The return type of the `apply` method, and an (empty or singleton) list
// of widening coercions
val (applyResultTpt, widenDefs) =
Expand Down Expand Up @@ -654,38 +665,53 @@ object desugar {
// todo: also use anyRef if constructor has a dependent method type (or rule that out)!
(constrVparamss :\ classTypeRef) (
(vparams, restpe) => Function(vparams map (_.tpt), restpe))
val companionParents =
if (isGenericProduct) companionParent :: mirrorMemberType("Product") :: Nil
else companionParent :: Nil
def widenedCreatorExpr =
(creatorExpr /: widenDefs)((rhs, meth) => Apply(Ident(meth.name), rhs :: Nil))
val applyMeths =
if (mods is Abstract) Nil
else {
val copiedFlagsMask = DefaultParameterized | (copiedAccessFlags & Private)
val appMods = {
val mods = Modifiers(Synthetic | constr1.mods.flags & copiedFlagsMask)
if (restrictedAccess) mods.withPrivateWithin(constr1.mods.privateWithin)
else mods
def applyDef = {
val copiedFlagsMask = DefaultParameterized | (copiedAccessFlags & Private)
val appMods = {
val mods = Modifiers(Synthetic | constr1.mods.flags & copiedFlagsMask)
if (restrictedAccess) mods.withPrivateWithin(constr1.mods.privateWithin)
else mods
}
DefDef(nme.apply, derivedTparams, derivedVparamss, applyResultTpt, widenedCreatorExpr)
.withMods(appMods)
}
val app = DefDef(nme.apply, derivedTparams, derivedVparamss, applyResultTpt, widenedCreatorExpr)
.withMods(appMods)
app :: widenDefs
applyDef :: widenDefs
}

val monoTypeDefs =
if (isGenericProduct) {
val monoType = appliedTypeTree(classTycon, constrTparams.map(_ => TypeBoundsTree(EmptyTree, EmptyTree)))
TypeDef(tpnme.MonoType, monoType).withMods(synthetic) :: Nil
}
else Nil

val unapplyMeth = {
val hasRepeatedParam = constrVparamss.head.exists {
case ValDef(_, tpt, _) => isRepeated(tpt)
}
val methName = if (hasRepeatedParam) nme.unapplySeq else nme.unapply
val unapplyParam = makeSyntheticParameter(tpt = classTypeRef)
val unapplyRHS = if (arity == 0) Literal(Constant(true)) else Ident(unapplyParam.name)
DefDef(methName, derivedTparams, (unapplyParam :: Nil) :: Nil, TypeTree(), unapplyRHS)
val param = makeSyntheticParameter(tpt = classTypeRef)
val rhs = if (arity == 0) Literal(Constant(true)) else Ident(param.name)
DefDef(methName, derivedTparams, (param :: Nil) :: Nil, TypeTree(), rhs)
.withMods(synthetic)
}
companionDefs(companionParent, applyMeths ::: unapplyMeth :: companionMembers)
companionDefs(
companionParents,
applyMeths ::: unapplyMeth :: monoTypeDefs ::: companionMembers)
}
else if (companionMembers.nonEmpty || companionDerived.nonEmpty || isEnum)
companionDefs(anyRef, companionMembers)
companionDefs(anyRef :: Nil, companionMembers)
else if (isValueClass) {
impl.constr.vparamss match {
case (_ :: Nil) :: _ => companionDefs(anyRef, Nil)
case (_ :: Nil) :: _ => companionDefs(anyRef :: Nil, Nil)
case _ => Nil // error will be emitted in typer
}
}
Expand Down Expand Up @@ -765,7 +791,7 @@ object desugar {
}

flatTree(cdef1 :: companions ::: implicitWrappers)
}
}.reporting(res => i"desugared: $res", Printers.desugar)

/** Expand
*
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/config/Printers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ object Printers {
val cyclicErrors: Printer = noPrinter
val debug = noPrinter // no type annotation here to force inlining
val derive: Printer = noPrinter
val desugar: Printer = noPrinter
val dottydoc: Printer = noPrinter
val exhaustivity: Printer = noPrinter
val gadts: Printer = noPrinter
Expand Down
10 changes: 10 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,16 @@ class Definitions {
lazy val ModuleSerializationProxyConstructor: TermSymbol =
ModuleSerializationProxyClass.requiredMethod(nme.CONSTRUCTOR, List(ClassType(TypeBounds.empty)))

//lazy val MirrorType: TypeRef = ctx.requiredClassRef("scala.deriving.Mirror")
lazy val Mirror_ProductType: TypeRef = ctx.requiredClassRef("scala.deriving.Mirror.Product")
def Mirror_ProductClass(implicit ctx: Context): ClassSymbol = Mirror_ProductType.symbol.asClass

lazy val Mirror_Product_fromProductR: TermRef = Mirror_ProductClass.requiredMethodRef(nme.fromProduct)
def Mirror_Product_fromProduct(implicit ctx: Context): Symbol = Mirror_Product_fromProductR.symbol

lazy val Mirror_SingletonType: TypeRef = ctx.requiredClassRef("scala.deriving.Mirror.Singleton")
def Mirror_SingletonClass(implicit ctx: Context): ClassSymbol = Mirror_SingletonType.symbol.asClass

lazy val GenericType: TypeRef = ctx.requiredClassRef("scala.reflect.Generic")
def GenericClass(implicit ctx: Context): ClassSymbol = GenericType.symbol.asClass
lazy val ShapeType: TypeRef = ctx.requiredClassRef("scala.compiletime.Shape")
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ object StdNames {
val longHash: N = "longHash"
val MatchCase: N = "MatchCase"
val Modifiers: N = "Modifiers"
val MonoType: N = "MonoType"
val NestedAnnotArg: N = "NestedAnnotArg"
val NoFlags: N = "NoFlags"
val NoPrefix: N = "NoPrefix"
Expand Down Expand Up @@ -432,6 +433,7 @@ object StdNames {
val flagsFromBits : N = "flagsFromBits"
val flatMap: N = "flatMap"
val foreach: N = "foreach"
val fromProduct: N = "fromProduct"
val genericArrayOps: N = "genericArrayOps"
val genericClass: N = "genericClass"
val get: N = "get"
Expand Down
11 changes: 10 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/SymUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import StdNames._
import NameKinds._
import Flags._
import Annotations._
import ValueClasses.isDerivedValueClass

import language.implicitConversions
import scala.annotation.tailrec
Expand Down Expand Up @@ -59,10 +60,18 @@ class SymUtils(val self: Symbol) extends AnyVal {

def isSuperAccessor(implicit ctx: Context): Boolean = self.name.is(SuperAccessorName)

/** A type or term parameter or a term parameter accessor */
/** Is this a type or term parameter or a term parameter accessor? */
def isParamOrAccessor(implicit ctx: Context): Boolean =
self.is(Param) || self.is(ParamAccessor)

/** Is this a case class for which a product mirror is generated?
* Excluded are value classes, abstract classes and case classes with more than one
* parameter section. See also: desugar.isGenericProduct */
def isGenericProduct(implicit ctx: Context): Boolean =
self.is(CaseClass, butNot = Abstract) &&
self.primaryConstructor.info.paramInfoss.length == 1 &&
!isDerivedValueClass(self)

/** If this is a constructor, its owner: otherwise this. */
final def skipConstructor(implicit ctx: Context): Symbol =
if (self.isConstructor) self.owner else self
Expand Down
59 changes: 59 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/SyntheticMethods.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ import DenotTransformers._
import Decorators._
import NameOps._
import Annotations.Annotation
import typer.ProtoTypes.constrained
import ast.untpd
import ValueClasses.isDerivedValueClass
import SymUtils._

/** Synthetic method implementations for case classes, case objects,
* and value classes.
Expand Down Expand Up @@ -38,18 +41,21 @@ class SyntheticMethods(thisPhase: DenotTransformer) {
private[this] var myValueSymbols: List[Symbol] = Nil
private[this] var myCaseSymbols: List[Symbol] = Nil
private[this] var myCaseModuleSymbols: List[Symbol] = Nil
private[this] var myProductMirrorSymbols: List[Symbol] = Nil

private def initSymbols(implicit ctx: Context) =
if (myValueSymbols.isEmpty) {
myValueSymbols = List(defn.Any_hashCode, defn.Any_equals)
myCaseSymbols = myValueSymbols ++ List(defn.Any_toString, defn.Product_canEqual,
defn.Product_productArity, defn.Product_productPrefix, defn.Product_productElement)
myCaseModuleSymbols = myCaseSymbols.filter(_ ne defn.Any_equals)
myProductMirrorSymbols = List(defn.Mirror_Product_fromProduct)
}

def valueSymbols(implicit ctx: Context): List[Symbol] = { initSymbols; myValueSymbols }
def caseSymbols(implicit ctx: Context): List[Symbol] = { initSymbols; myCaseSymbols }
def caseModuleSymbols(implicit ctx: Context): List[Symbol] = { initSymbols; myCaseModuleSymbols }
def productMirrorSymbols(implicit ctx: Context): List[Symbol] = { initSymbols; myProductMirrorSymbols }

/** If this is a case or value class, return the appropriate additional methods,
* otherwise return nothing.
Expand All @@ -66,6 +72,7 @@ class SyntheticMethods(thisPhase: DenotTransformer) {
else caseSymbols
}
else if (isDerivedValueClass(clazz)) valueSymbols
else if (clazz.is(Module) && clazz.linkedClass.isGenericProduct) productMirrorSymbols
else Nil

def syntheticDefIfMissing(sym: Symbol): List[Tree] = {
Expand All @@ -78,6 +85,7 @@ class SyntheticMethods(thisPhase: DenotTransformer) {
val synthetic = sym.copy(
owner = clazz,
flags = sym.flags &~ Deferred | Synthetic | Override,
info = clazz.thisType.memberInfo(sym),
coord = clazz.coord).enteredAfter(thisPhase).asTerm

def forwardToRuntime(vrefss: List[List[Tree]]): Tree =
Expand All @@ -95,6 +103,10 @@ class SyntheticMethods(thisPhase: DenotTransformer) {
case nme.productArity => vrefss => Literal(Constant(accessors.length))
case nme.productPrefix => ownName
case nme.productElement => vrefss => productElementBody(accessors.length, vrefss.head.head)
case nme.fromProduct =>
vrefss =>
fromProductBody(accessors, vrefss.head.head)
.ensureConforms(synthetic.info.finalResultType)
}
ctx.log(s"adding $synthetic to $clazz at ${ctx.phase}")
DefDef(synthetic, syntheticRHS(ctx.withOwner(synthetic))).withSpan(ctx.owner.span.focus)
Expand Down Expand Up @@ -138,6 +150,53 @@ class SyntheticMethods(thisPhase: DenotTransformer) {
Match(index, (cases :+ defaultCase).toList)
}

/** The class
*
* ```
* case class C[T <: U](x: T, y: String*)
* ```
*
* gets the `fromProduct` method:
*
* ```
* def fromProduct(x$0: Product): MonoType =
* new C[U](
* x$0.productElement(0).asInstanceOf[U],
* x$0.productElement(1).asInstanceOf[Seq[String]]: _*)
* ```
* where
* ```
* type MonoType = C[_]
* ```
*/
def fromProductBody(accessors: List[Symbol], prod: Tree)(implicit ctx: Context): Tree = {
val caseClass = clazz.linkedClass
val (classRef, methTpe) =
caseClass.primaryConstructor.info match {
case tl: PolyType =>
val (tl1, tpts) = constrained(tl, untpd.EmptyTree, alwaysAddTypeVars = true)
val targs =
for (tpt <- tpts) yield
tpt.tpe match {
case tvar: TypeVar => tvar.instantiate(fromBelow = false)
}
(caseClass.typeRef.appliedTo(targs), tl.instantiate(targs))
case methTpe =>
(caseClass.typeRef, methTpe)
}
methTpe match {
case methTpe: MethodType =>
val elems =
for ((formal, idx) <- methTpe.paramInfos.zipWithIndex) yield {
val elem =
prod.select(defn.Product_productElement).appliedTo(Literal(Constant(idx)))
.ensureConforms(formal.underlyingIfRepeated(isJava = false))
if (formal.isRepeatedParam) ctx.typer.seqToRepeated(elem) else elem
}
New(classRef, elems)
}
}

/** The class
*
* ```
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ class Namer { typer: Typer =>
case _ => tree
}

/** For all class definitions `stat` in `xstats`: If the companion class if
/** For all class definitions `stat` in `xstats`: If the companion class is
* not also defined in `xstats`, invalidate it by setting its info to
* NoType.
*/
Expand Down Expand Up @@ -702,7 +702,7 @@ class Namer { typer: Typer =>
// If a top-level object or class has no companion in the current run, we
// enter a dummy companion (`denot.isAbsent` returns true) in scope. This
// ensures that we never use a companion from a previous run or from the
// classpath. See tests/pos/false-companion for an example where this
// class path. See tests/pos/false-companion for an example where this
// matters.
if (ctx.owner.is(PackageClass)) {
for (cdef @ TypeDef(moduleName, _) <- moduleDef.values) {
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ trait TypeAssigner {
private def toRepeated(tree: Tree, from: ClassSymbol)(implicit ctx: Context): Tree =
Typed(tree, TypeTree(tree.tpe.widen.translateParameterized(from, defn.RepeatedParamClass)))

def seqToRepeated(tree: Tree)(implicit ctx: Context): Tree = toRepeated(tree, defn.SeqClass)
def seqToRepeated(tree: Tree)(implicit ctx: Context): Tree = toRepeated(tree, defn.SeqClass)

def arrayToRepeated(tree: Tree)(implicit ctx: Context): Tree = toRepeated(tree, defn.ArrayClass)
def arrayToRepeated(tree: Tree)(implicit ctx: Context): Tree = toRepeated(tree, defn.ArrayClass)

/** A denotation exists really if it exists and does not point to a stale symbol. */
final def reallyExists(denot: Denotation)(implicit ctx: Context): Boolean = try
Expand Down
Loading

0 comments on commit 8487127

Please sign in to comment.