Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Derive a type class only for wrapper types (unary product types) #278

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/core/interface.scala
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ abstract class ReadOnlyCaseClass[Typeclass[_], Type](
final def typeAnnotations: Seq[Any] = typeAnnotationsArray
}


/** [[CaseClass]] contains all information that exists in a [[ReadOnlyCaseClass]], as well as methods and context
* required for construct an instance of this case class/object (e.g. default values for constructor parameters)
*
Expand Down Expand Up @@ -440,6 +441,25 @@ final case class TypeName(owner: String, short: String, typeArguments: Seq[TypeN
*/
final class debug(typeNamePart: String = "") extends scala.annotation.StaticAnnotation

object typeValidation {
/**
* This annotation can be attached to the `combine` method of a type class companion.
* If specified, it will check that we only derive type classes for types with more than `n` members
*
* @param n inclusive, exactly `n` members is fine
*/
final class minMembers(n: Int) extends scala.annotation.StaticAnnotation

/**
* This annotation can be attached to the `combine` method of a type class companion.
* If specified, it will check that we only derive type classes for types with less than `n` members
*
* @param n inclusive, exactly `n` members is fine
*/
final class maxMembers(n: Int) extends scala.annotation.StaticAnnotation

}

private[magnolia] final case class EarlyExit[E](e: E) extends Exception with util.control.NoStackTrace

object MagnoliaUtil {
Expand Down
87 changes: 67 additions & 20 deletions src/core/magnolia.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,52 @@ object Magnolia {
val prefixType = c.prefix.tree.tpe
val prefixObject = prefixType.typeSymbol
val prefixName = prefixObject.name.decodedName
val prefixBaseClasses = c.prefix.tree.tpe.baseClasses

def extractMethod(termName: String): Option[MethodSymbol] = {
val term = TermName(termName)
prefixBaseClasses
.find(cls => cls.asType.toType.decl(term) != NoSymbol)
.map(cls => cls.asType.toType.decl(term).asTerm.asMethod)
}

val combineMethodOpt = extractMethod("combine")

object validate {
val MinMembersTpe = typeOf[typeValidation.minMembers]
val MaxMembersTpe = typeOf[typeValidation.maxMembers]

val annotations = combineMethodOpt match {
case Some(combineMethod) => combineMethod.annotations
case None => Nil
}

val minMembersOpt: Option[Int] =
annotations
.collectFirst { case a if a.tree.tpe <:< MinMembersTpe => a.tree.children(1) }
.collectFirst { case Literal(Constant(arg: Int)) => arg }

val maxMembersOpt: Option[Int] =
annotations
.collectFirst { case a if a.tree.tpe <:< MaxMembersTpe => a.tree.children(1) }
.collectFirst { case Literal(Constant(arg: Int)) => arg }

def apply(members: List[TermSymbol]): Unit = {
val numMembers = members.length

minMembersOpt match {
case Some(min) if numMembers < min =>
error(s"$genericType is not a valid type for $prefixName.Typeclass because at least $min members are required (it has $numMembers)")
case _ => ()
}

maxMembersOpt match {
case Some(max) if numMembers > max =>
error(s"$genericType is not a valid type for $prefixName.Typeclass because at no more than $max members are required (it has $numMembers)")
case _ => ()
}
}
}

val debug = c.macroApplication.symbol.annotations
.find(_.tree.tpe <:< DebugTpe)
Expand Down Expand Up @@ -226,27 +272,17 @@ object Magnolia {
annotationTrees(typeAnnotations)
}

def checkMethod(termName: String, category: String, expected: String): Unit = {
val firstParamBlock = extractParameterBlockFor(termName, category)
if (firstParamBlock.lengthCompare(1) != 0)
error(s"the method `$termName` should take a single parameter of type $expected")
}

def extractParameterBlockFor(termName: String, category: String): List[Symbol] = {
val term = TermName(termName)
val classWithTerm = c.prefix.tree.tpe.baseClasses
.find(cls => cls.asType.toType.decl(term) != NoSymbol)
.getOrElse(error(s"the method `$termName` must be defined on the derivation $prefixObject to derive typeclasses for $category"))

classWithTerm.asType.toType.decl(term).asTerm.asMethod.paramLists.head
}
lazy val (isReadOnly, caseClassSymbol, paramSymbol) = {
val combine = combineMethodOpt getOrElse {
error(s"the method `combine` must be defined on the derivation $prefixObject to derive typeclasses for case classes")
}

lazy val (isReadOnly, caseClassSymbol, paramSymbol) =
extractParameterBlockFor("combine", "case classes").headOption.map(_.typeSignature.typeSymbol) match {
combine.paramLists.head.headOption.map(_.typeSignature.typeSymbol) match {
case Some(ReadOnlyCaseClassSym) => (true, ReadOnlyCaseClassSym, ReadOnlyParamSym)
case Some(CaseClassSym) => (false, CaseClassSym, ParamSym)
case Some(CaseClassSym) => (false, CaseClassSym, ParamSym)
case _ => error("Parameter for `combine` needs be either magnolia.CaseClass or magnolia.ReadOnlyCaseClass")
}
}

// fullAuto means we should directly infer everything, including external
// members of the ADT, that isn't inferred by the compiler.
Expand Down Expand Up @@ -368,6 +404,8 @@ object Magnolia {
val result = if (isRefinedType) {
error(s"could not infer $prefixName.Typeclass for refined type $genericType")
} else if (isCaseObject) {
validate(members = Nil)

val classBody = if (isReadOnly) List(EmptyTree) else {
val module = Ident(genericType.typeSymbol.asClass.module)
List(
Expand Down Expand Up @@ -402,6 +440,8 @@ object Magnolia {
else { case p: TermSymbol if p.isCaseAccessor && !p.isMethod => p }
)

validate(caseClassParameters)

val (factoryObject, factoryMethod) = {
if (isReadOnly && isValueClass) ReadOnlyParamObj -> TermName("valueParam")
else if (isReadOnly) ReadOnlyParamObj -> TermName("apply")
Expand Down Expand Up @@ -561,7 +601,14 @@ object Magnolia {
})
}""")
} else if (isSealedTrait) {
checkMethod("dispatch", "sealed traits", "SealedTrait[Typeclass, _]")
val firstParamBlock = extractMethod("dispatch") match {
case Some(dispatch) => dispatch.paramLists.head
case None => error(s"the method `dispatch` must be defined on the derivation $prefixObject to derive typeclasses for sealed traits")
}

if (firstParamBlock.lengthCompare(1) != 0)
error("the method `dispatch` should take a single parameter of type SealedTrait[Typeclass, _]")

val genericSubtypes = knownSubclassesOf(classType.get).toList.sortBy(_.fullName)
val subtypes = genericSubtypes.flatMap { sub =>
val subType = sub.asType.toType // FIXME: Broken for path dependent types
Expand Down Expand Up @@ -613,7 +660,7 @@ object Magnolia {
))
}""")
} else if (!typeSymbol.isParameter) {
c.prefix.tree.tpe.baseClasses
prefixBaseClasses
.find { cls =>
cls.asType.toType.decl(TermName("fallback")) != NoSymbol
}.map { _ =>
Expand All @@ -628,7 +675,7 @@ object Magnolia {
}"""
}

val typeDefs = prefixType.baseClasses.flatMap { baseClass =>
val typeDefs = prefixBaseClasses.flatMap { baseClass =>
baseClass.asType.toType.decls.collectFirst {
case tpe: TypeSymbol if tpe.name == TypeClassNme =>
tpe.toType.asSeenFrom(prefixType, baseClass)
Expand Down
65 changes: 65 additions & 0 deletions src/examples/wrappers.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*

Magnolia, version 0.17.0. Copyright 2018-20 Jon Pretty, Propensive OÜ.

The primary distribution site is: https://propensive.com/

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License.

*/
package magnolia.examples
import magnolia._
import scala.language.experimental.macros

/* automatically derived only for wrapper types (unary product types) */
trait ToString[A] {
def str(a: A): String
}

object ToString {
def apply[A: ToString]: ToString[A] = implicitly

type Typeclass[A] = ToString[A]

@typeValidation.minMembers(1)
@typeValidation.maxMembers(1)
def combine[A](ctx: ReadOnlyCaseClass[ToString, A]): ToString[A] =
(a: A) => {
val param = ctx.parameters.head
param.typeclass.str(param.dereference(a))
}

implicit def derive[A]: ToString[A] = macro Magnolia.gen[A]

implicit val str: ToString[String] = (a: String) => a
implicit val int: ToString[Int] = (a: Int) => a.toString
}


trait FromString[A] {
def fromStr(str: String): A
}

object FromString {
def apply[A: FromString]: FromString[A] = implicitly

type Typeclass[A] = FromString[A]

@typeValidation.minMembers(1)
@typeValidation.maxMembers(1)
def combine[A](ctx: CaseClass[FromString, A]): FromString[A] =
(str: String) => ctx.construct(p => p.typeclass.fromStr(str))

implicit def derive[A]: FromString[A] = macro Magnolia.gen[A]

implicit val str: FromString[String] = (a: String) => a
implicit val int: FromString[Int] = (a: String) => a.toInt

}
56 changes: 56 additions & 0 deletions src/test/tests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,13 @@ final case class Huey(height: Int) extends GoodChild
class Dewey(val height: Int) extends GoodChild
final case class Louie(height: Int) extends BadChild

case class StringWrapper(str: String)
case class StringWrapWrapper(w: StringWrapper)
case class StringWrapAnyVal(w: StringWrapper) extends AnyVal
class StringWrapNotCaseClass(val w: Int) extends AnyVal
case class StringWrapperComposite(str: String, int: Int)
case object A

object Tests extends Suite("Magnolia tests") {

def run(test: Runner): Unit = for (_ <- 1 to 1) {
Expand Down Expand Up @@ -732,7 +739,56 @@ object Tests extends Suite("Magnolia tests") {
}

test("support dispatch without combine") {
NoCombine.gen[Halfy]
implicitly[NoCombine[Halfy]].nameOf(Righty())
}.assert(_ == "Righty")

test("readonly unary product types: support wrapper case class") {
ToString[StringWrapper].str(StringWrapper("a"))
}.assert(_ == "a")

test("readonly unary product types: support nested wrapper case class") {
ToString[StringWrapWrapper].str(StringWrapWrapper(StringWrapper("a")))
}.assert(_ == "a")

test("readonly unary product types: support wrapper case class extending AnyVal") {
ToString[StringWrapAnyVal].str(StringWrapAnyVal(StringWrapper("a")))
}.assert(_ == "a")

test("readonly unary product types: support non-case wrapper class") {
ToString[StringWrapNotCaseClass].str(new StringWrapNotCaseClass(1))
}.assert(_ == "1")

test("readonly unary product types: not support case object unary product type") {
scalac"ToString.derive[A.type]"
}.assert(_ == TypecheckError(txt"magnolia: magnolia.tests.A.type is not a valid type for ToString.Typeclass because at least 1 members are required (it has 0)"))

test("readonly unary product types: not support case class with two members") {
scalac"ToString.derive[StringWrapperComposite.type]"
}.assert(_ == TypecheckError(txt"magnolia: magnolia.tests.StringWrapperComposite.type is not a valid type for ToString.Typeclass because at least 1 members are required (it has 0)"))

test("unary product types: support wrapper case class") {
FromString[StringWrapper].fromStr("a")
}.assert(_ == StringWrapper("a"))

test("unary product types: support nested wrapper case class") {
FromString[StringWrapWrapper].fromStr("a")
}.assert(_ == StringWrapWrapper(StringWrapper("a")))

test("unary product types: support wrapper case class extending AnyVal") {
FromString[StringWrapAnyVal].fromStr("a")
}.assert(_ == StringWrapAnyVal(StringWrapper("a")))

test("unary product types: support non-case wrapper class") {
FromString[StringWrapNotCaseClass].fromStr("1")
}.assert(_ == new StringWrapNotCaseClass(1))

test("unary product types: not support case object unary product type") {
scalac"FromString.derive[A.type]"
}.assert(_ == TypecheckError(txt"magnolia: magnolia.tests.A.type is not a valid type for FromString.Typeclass because at least 1 members are required (it has 0)"))

test("unary product types: not support case class with two members") {
scalac"FromString.derive[StringWrapperComposite]"
}.assert(_ == TypecheckError(txt"magnolia: magnolia.tests.StringWrapperComposite is not a valid type for FromString.Typeclass because at no more than 1 members are required (it has 2)"))
}
}