Skip to content

Commit

Permalink
Make mirrored type a type member
Browse files Browse the repository at this point in the history
Make mirrored type a type member instead of a type parameter, following
Miles' design.
  • Loading branch information
odersky committed May 18, 2019
1 parent 314567a commit d5f49b5
Showing 1 changed file with 56 additions and 37 deletions.
93 changes: 56 additions & 37 deletions tests/run/typeclass-derivation2d.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,23 @@ object Deriving {
/** The Generic class hierarchy allows typelevel access to
* enums, case classes and objects, and their sealed parents.
*/
sealed abstract class Mirror[T]
sealed abstract class Mirror { type MirroredType }

type MirrorOf[T] = Mirror { type MirroredType = T }

object Mirror {

/** The Mirror for a sum type */
trait Sum[T] extends Mirror[T] { self =>
trait Sum extends Mirror { self =>

type ElemTypes <: Tuple

/** The ordinal number of the case class of `x`. For enums, `ordinal(x) == x.ordinal` */
def ordinal(x: T): Int
def ordinal(x: MirroredType): Int
}

/** The Mirror for a product type */
trait Product[T] extends Mirror[T] {
trait Product extends Mirror {

/** The types of the elements */
type ElemTypes <: Tuple
Expand All @@ -37,8 +39,11 @@ object Deriving {
type ElemLabels <: Tuple

/** Create a new instance of type `T` with elements taken from product `p`. */
def fromProduct(p: scala.Product): T
def fromProduct(p: scala.Product): MirroredType
}

type SumOf[T] = Sum { type MirroredType = T }
type ProductOf[T] = Product { type MirroredType = T }
}

/** Helper class to turn arrays into products */
Expand All @@ -61,37 +66,43 @@ import Deriving._

sealed trait Lst[+T] // derives Eq, Pickler, Show

object Lst extends Mirror.Sum[Lst[_]] {
object Lst extends Mirror.Sum {
type MirroredType = Lst[_]

def ordinal(x: Lst[_]) = x match {
case x: Cons[_] => 0
case Nil => 1
}

implicit def mirror[T]: Mirror.Sum[Lst[T]] {
implicit def mirror[T]: Mirror.Sum {
type MirroredType = Lst[T]
type ElemTypes = (Cons[T], Nil.type)
} = this.asInstanceOf

case class Cons[T](hd: T, tl: Lst[T]) extends Lst[T]

object Cons extends Mirror.Product[Cons[_]] {
object Cons extends Mirror.Product {
type MirroredType = Cons[_]

def apply[T](x: T, xs: Lst[T]): Lst[T] = new Cons(x, xs)

def fromProduct(p: Product): Cons[_] =
new Cons(productElement[Any](p, 0), productElement[Lst[Any]](p, 1))

implicit def mirror[T]: Mirror.Product[Cons[T]] {
implicit def mirror[T]: Mirror.Product {
type MirroredType = Cons[T]
type ElemTypes = (T, Lst[T])
type CaseLabel = "Cons"
type ElemLabels = ("hd", "tl")
} = this.asInstanceOf
}

case object Nil extends Lst[Nothing] with Mirror.Product[Nil.type] {
case object Nil extends Lst[Nothing] with Mirror.Product {
type MirroredType = Nil.type
def fromProduct(p: Product): Nil.type = Nil

implicit def mirror: Mirror.Product[Nil.type] {
implicit def mirror: Mirror.Product {
type MirroredType = Nil.type
type ElemTypes = Unit
type CaseLabel = "Nil"
type ElemLabels = Unit
Expand All @@ -108,12 +119,14 @@ object Lst extends Mirror.Sum[Lst[_]] {

case class Pair[T](x: T, y: T) // derives Eq, Pickler, Show

object Pair extends Mirror.Product[Pair[_]] {
object Pair extends Mirror.Product {
type MirroredType = Pair[_]

def fromProduct(p: Product): Pair[_] =
Pair(productElement[Any](p, 0), productElement[Any](p, 1))

implicit def mirror[T]: Mirror.Product[Pair[T]] {
implicit def mirror[T]: Mirror.Product {
type MirroredType = Pair[T]
type ElemTypes = (T, T)
type CaseLabel = "Pair"
type ElemLabels = ("x", "y")
Expand All @@ -129,14 +142,16 @@ object Pair extends Mirror.Product[Pair[_]] {

sealed trait Either[+L, +R] extends Product with Serializable // derives Eq, Pickler, Show

object Either extends Mirror.Sum[Either[_, _]] {
object Either extends Mirror.Sum {
type MirroredType = Either[_, _]

def ordinal(x: Either[_, _]) = x match {
case x: Left[_] => 0
case x: Right[_] => 1
}

implicit def mirror[L, R]: Mirror.Sum[Either[L, R]] {
implicit def mirror[L, R]: Mirror.Sum {
type MirroredType = Either[L, R]
type ElemTypes = (Left[L], Right[R])
} = this.asInstanceOf

Expand All @@ -148,18 +163,22 @@ object Either extends Mirror.Sum[Either[_, _]] {
case class Left[L](elem: L) extends Either[L, Nothing]
case class Right[R](elem: R) extends Either[Nothing, R]

object Left extends Mirror.Product[Left[_]] {
object Left extends Mirror.Product {
type MirroredType = Left[_]
def fromProduct(p: Product): Left[_] = Left(productElement[Any](p, 0))
implicit def mirror[L]: Mirror.Product[Left[L]] {
implicit def mirror[L]: Mirror.Product {
type MirroredType = Left[L]
type ElemTypes = L *: Unit
type CaseLabel = "Left"
type ElemLabels = "x" *: Unit
} = this.asInstanceOf
}

object Right extends Mirror.Product[Right[_]] {
object Right extends Mirror.Product {
type MirroredType = Right[_]
def fromProduct(p: Product): Right[_] = Right(productElement[Any](p, 0))
implicit def mirror[R]: Mirror.Product[Right[R]] {
implicit def mirror[R]: Mirror.Product {
type MirroredType = Right[R]
type ElemTypes = R *: Unit
type CaseLabel = "Right"
type ElemLabels = "x" *: Unit
Expand Down Expand Up @@ -188,28 +207,28 @@ object Eq {
true
}

inline def eqlProduct[T](m: Mirror.Product[T])(x: Any, y: Any): Boolean =
inline def eqlProduct[T](m: Mirror.ProductOf[T])(x: Any, y: Any): Boolean =
eqlElems[m.ElemTypes](0)(x, y)

inline def eqlCases[Alts](n: Int)(x: Any, y: Any, ord: Int): Boolean =
inline erasedValue[Alts] match {
case _: (alt *: alts1) =>
if (ord == n)
implicit match {
case m: Mirror.Product[`alt`] => eqlElems[m.ElemTypes](0)(x, y)
case m: Mirror.ProductOf[`alt`] => eqlElems[m.ElemTypes](0)(x, y)
}
else eqlCases[alts1](n + 1)(x, y, ord)
case _: Unit =>
false
}

inline def derived[T](implicit ev: Mirror[T]): Eq[T] = new Eq[T] {
inline def derived[T](implicit ev: MirrorOf[T]): Eq[T] = new Eq[T] {
def eql(x: T, y: T): Boolean =
inline ev match {
case m: Mirror.Sum[T] =>
case m: Mirror.SumOf[T] =>
val ord = m.ordinal(x)
ord == m.ordinal(y) && eqlCases[m.ElemTypes](0)(x, y, ord)
case m: Mirror.Product[T] =>
case m: Mirror.ProductOf[T] =>
eqlElems[m.ElemTypes](0)(x, y)
}
}
Expand Down Expand Up @@ -248,7 +267,7 @@ object Pickler {
case _: (alt *: alts1) =>
if (ord == n)
implicit match {
case m: Mirror.Product[`alt`] => pickleElems[m.ElemTypes](0)(buf, x)
case m: Mirror.ProductOf[`alt`] => pickleElems[m.ElemTypes](0)(buf, x)
}
else pickleCases[alts1](n + 1)(buf, x, ord)
case _: Unit =>
Expand All @@ -266,7 +285,7 @@ object Pickler {
case _: Unit =>
}

inline def unpickleCase[T, Elems <: Tuple](buf: mutable.ListBuffer[Int], m: Mirror.Product[T]): T = {
inline def unpickleCase[T, Elems <: Tuple](buf: mutable.ListBuffer[Int], m: Mirror.ProductOf[T]): T = {
inline val size = constValue[Tuple.Size[Elems]]
inline if (size == 0)
m.fromProduct(EmptyProduct)
Expand All @@ -282,30 +301,30 @@ object Pickler {
case _: (alt *: alts1) =>
if (ord == n)
implicit match {
case m: Mirror.Product[`alt` & T] =>
case m: Mirror.ProductOf[`alt` & T] =>
unpickleCase[`alt` & T, m.ElemTypes](buf, m)
}
else unpickleCases[T, alts1](n + 1)(buf, ord)
case _: Unit =>
throw new IndexOutOfBoundsException(s"unexpected ordinal number: $ord")
}

inline def derived[T](implicit ev: Mirror[T]): Pickler[T] = new {
inline def derived[T](implicit ev: MirrorOf[T]): Pickler[T] = new {
def pickle(buf: mutable.ListBuffer[Int], x: T): Unit =
inline ev match {
case m: Mirror.Sum[T] =>
case m: Mirror.SumOf[T] =>
val ord = m.ordinal(x)
buf += ord
pickleCases[m.ElemTypes](0)(buf, x, ord)
case m: Mirror.Product[T] =>
case m: Mirror.ProductOf[T] =>
pickleElems[m.ElemTypes](0)(buf, x)
}
def unpickle(buf: mutable.ListBuffer[Int]): T =
inline ev match {
case m: Mirror.Sum[T] =>
case m: Mirror.SumOf[T] =>
val ord = nextInt(buf)
unpickleCases[T, m.ElemTypes](0)(buf, ord)
case m: Mirror.Product[T] =>
case m: Mirror.ProductOf[T] =>
unpickleCase[T, m.ElemTypes](buf, m)
}
}
Expand Down Expand Up @@ -341,7 +360,7 @@ object Show {
Nil
}

inline def showCase(x: Any, m: Mirror.Product[_]): String = {
inline def showCase(x: Any, m: Mirror.ProductOf[_]): String = {
val label = constValue[m.CaseLabel]
showElems[m.ElemTypes, m.ElemLabels](0)(x).mkString(s"$label(", ", ", ")")
}
Expand All @@ -351,21 +370,21 @@ object Show {
case _: (alt *: alts1) =>
if (ord == n)
implicit match {
case m: Mirror.Product[`alt`] =>
case m: Mirror.ProductOf[`alt`] =>
showCase(x, m)
}
else showCases[alts1](n + 1)(x, ord)
case _: Unit =>
throw new MatchError(x)
}

inline def derived[T](implicit ev: Mirror[T]): Show[T] = new {
inline def derived[T](implicit ev: MirrorOf[T]): Show[T] = new {
def show(x: T): String =
inline ev match {
case m: Mirror.Sum[T] =>
case m: Mirror.SumOf[T] =>
val ord = m.ordinal(x)
showCases[m.ElemTypes](0)(x, ord)
case m: Mirror.Product[T] =>
case m: Mirror.ProductOf[T] =>
showCase(x, m)
}
}
Expand Down

0 comments on commit d5f49b5

Please sign in to comment.