Skip to content

Commit

Permalink
Special treatment for Singletons
Browse files Browse the repository at this point in the history
  • Loading branch information
odersky committed May 18, 2019
1 parent 1a709f1 commit c77fdcc
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 43 deletions.
8 changes: 4 additions & 4 deletions tests/run/typeclass-derivation2d.check
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ ListBuffer(0, 0, 11, 0, 22, 0, 33, 1, 0, 0, 11, 0, 22, 1, 1)
Cons(Cons(11,Cons(22,Cons(33,Nil))),Cons(Cons(11,Cons(22,Nil)),Nil))
ListBuffer(1, 2)
Pair(1,2)
Cons(hd = 11, tl = Cons(hd = 22, tl = Cons(hd = 33, tl = Nil())))
Cons(hd = Cons(hd = 11, tl = Cons(hd = 22, tl = Cons(hd = 33, tl = Nil()))), tl = Cons(hd = Cons(hd = 11, tl = Cons(hd = 22, tl = Nil())), tl = Nil()))
Cons(hd = Left(x = 1), tl = Cons(hd = Right(x = Pair(x = 2, y = 3)), tl = Nil()))
Cons(hd = Left(x = 1), tl = Cons(hd = Right(x = Pair(x = 2, y = 3)), tl = Nil()))
Cons(hd = 11, tl = Cons(hd = 22, tl = Cons(hd = 33, tl = Nil)))
Cons(hd = Cons(hd = 11, tl = Cons(hd = 22, tl = Cons(hd = 33, tl = Nil))), tl = Cons(hd = Cons(hd = 11, tl = Cons(hd = 22, tl = Nil)), tl = Nil))
Cons(hd = Left(x = 1), tl = Cons(hd = Right(x = Pair(x = 2, y = 3)), tl = Nil))
Cons(hd = Left(x = 1), tl = Cons(hd = Right(x = Pair(x = 2, y = 3)), tl = Nil))
106 changes: 67 additions & 39 deletions tests/run/typeclass-derivation2d.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,29 @@ object Deriving {
/** The Generic class hierarchy allows typelevel access to
* enums, case classes and objects, and their sealed parents.
*/
sealed abstract class Mirror[T]
sealed abstract class Mirror {

/** The mirrored *-type */
type MonoType
}
type MirrorOf[T] = Mirror { type MonoType = T }
type ProductMirrorOf[T] = Mirror.Product { type MonoType = T }
type SumMirrorOf[T] = Mirror.Sum { type MonoType = T }
type SingletonMirror = Mirror.Singleton

object Mirror {

/** The Mirror for a sum type */
trait Sum[T] extends Mirror[T] { self =>
trait Sum extends Mirror { self =>

type ElemTypes <: Tuple

/** The ordinal number of the case class of `x`. For enums, `ordinal(x) == x.ordinal` */
def ordinal(x: T): Int
def ordinal(x: MonoType): Int
}

/** The Mirror for a product type */
trait Product[T] extends Mirror[T] {
trait Product extends Mirror {

/** The types of the elements */
type ElemTypes <: Tuple
Expand All @@ -37,7 +45,12 @@ object Deriving {
type ElemLabels <: Tuple

/** Create a new instance of type `T` with elements taken from product `p`. */
def fromProduct(p: scala.Product): T
def fromProduct(p: scala.Product): MonoType
}

trait Singleton extends Product {
type MonoType = this.type
def fromProduct(p: scala.Product) = this
}
}

Expand All @@ -61,37 +74,41 @@ import Deriving._

sealed trait Lst[+T] // derives Eq, Pickler, Show

object Lst extends Mirror.Sum[Lst[_]] {
object Lst extends Mirror.Sum {
type MonoType = Lst[_]

def ordinal(x: Lst[_]) = x match {
case x: Cons[_] => 0
case Nil => 1
}

implicit def mirror[T]: Mirror.Sum[Lst[T]] {
implicit def mirror[T]: Mirror.Sum {
type MonoType = Lst[T]
type ElemTypes = (Cons[T], Nil.type)
} = this.asInstanceOf

case class Cons[T](hd: T, tl: Lst[T]) extends Lst[T]

object Cons extends Mirror.Product[Cons[_]] {
object Cons extends Mirror.Product {
type MonoType = Lst[_]

def apply[T](x: T, xs: Lst[T]): Lst[T] = new Cons(x, xs)

def fromProduct(p: Product): Cons[_] =
new Cons(productElement[Any](p, 0), productElement[Lst[Any]](p, 1))

implicit def mirror[T]: Mirror.Product[Cons[T]] {
implicit def mirror[T]: Mirror.Product {
type MonoType = Cons[T]
type ElemTypes = (T, Lst[T])
type CaseLabel = "Cons"
type ElemLabels = ("hd", "tl")
} = this.asInstanceOf
}

case object Nil extends Lst[Nothing] with Mirror.Product[Nil.type] {
def fromProduct(p: Product): Nil.type = Nil
case object Nil extends Lst[Nothing] with Mirror.Singleton {

implicit def mirror: Mirror.Product[Nil.type] {
implicit def mirror: Mirror.Singleton {
type MonoType = Nil.type
type ElemTypes = Unit
type CaseLabel = "Nil"
type ElemLabels = Unit
Expand All @@ -108,12 +125,14 @@ object Lst extends Mirror.Sum[Lst[_]] {

case class Pair[T](x: T, y: T) // derives Eq, Pickler, Show

object Pair extends Mirror.Product[Pair[_]] {
object Pair extends Mirror.Product {
type MonoType = Pair[_]

def fromProduct(p: Product): Pair[_] =
Pair(productElement[Any](p, 0), productElement[Any](p, 1))

implicit def mirror[T]: Mirror.Product[Pair[T]] {
implicit def mirror[T]: Mirror.Product {
type MonoType = Pair[T]
type ElemTypes = (T, T)
type CaseLabel = "Pair"
type ElemLabels = ("x", "y")
Expand All @@ -129,14 +148,16 @@ object Pair extends Mirror.Product[Pair[_]] {

sealed trait Either[+L, +R] extends Product with Serializable // derives Eq, Pickler, Show

object Either extends Mirror.Sum[Either[_, _]] {
object Either extends Mirror.Sum {
type MonoType = Either[_, _]

def ordinal(x: Either[_, _]) = x match {
case x: Left[_] => 0
case x: Right[_] => 1
}

implicit def mirror[L, R]: Mirror.Sum[Either[L, R]] {
implicit def mirror[L, R]: Mirror.Sum {
type MonoType = Either[L, R]
type ElemTypes = (Left[L], Right[R])
} = this.asInstanceOf

Expand All @@ -148,18 +169,22 @@ object Either extends Mirror.Sum[Either[_, _]] {
case class Left[L](elem: L) extends Either[L, Nothing]
case class Right[R](elem: R) extends Either[Nothing, R]

object Left extends Mirror.Product[Left[_]] {
object Left extends Mirror.Product {
type MonoType = Left[_]
def fromProduct(p: Product): Left[_] = Left(productElement[Any](p, 0))
implicit def mirror[L]: Mirror.Product[Left[L]] {
implicit def mirror[L]: Mirror.Product {
type MonoType = Left[L]
type ElemTypes = L *: Unit
type CaseLabel = "Left"
type ElemLabels = "x" *: Unit
} = this.asInstanceOf
}

object Right extends Mirror.Product[Right[_]] {
object Right extends Mirror.Product {
type MonoType = Right[_]
def fromProduct(p: Product): Right[_] = Right(productElement[Any](p, 0))
implicit def mirror[R]: Mirror.Product[Right[R]] {
implicit def mirror[R]: Mirror.Product {
type MonoType = Right[R]
type ElemTypes = R *: Unit
type CaseLabel = "Right"
type ElemLabels = "x" *: Unit
Expand Down Expand Up @@ -188,28 +213,28 @@ object Eq {
true
}

inline def eqlProduct[T](m: Mirror.Product[T])(x: Any, y: Any): Boolean =
inline def eqlProduct[T](m: ProductMirrorOf[T])(x: Any, y: Any): Boolean =
eqlElems[m.ElemTypes](0)(x, y)

inline def eqlCases[Alts](n: Int)(x: Any, y: Any, ord: Int): Boolean =
inline erasedValue[Alts] match {
case _: (alt *: alts1) =>
if (ord == n)
implicit match {
case m: Mirror.Product[`alt`] => eqlElems[m.ElemTypes](0)(x, y)
case m: ProductMirrorOf[`alt`] => eqlElems[m.ElemTypes](0)(x, y)
}
else eqlCases[alts1](n + 1)(x, y, ord)
case _: Unit =>
false
}

inline def derived[T](implicit ev: Mirror[T]): Eq[T] = new Eq[T] {
inline def derived[T](implicit ev: MirrorOf[T]): Eq[T] = new Eq[T] {
def eql(x: T, y: T): Boolean =
inline ev match {
case m: Mirror.Sum[T] =>
case m: SumMirrorOf[T] =>
val ord = m.ordinal(x)
ord == m.ordinal(y) && eqlCases[m.ElemTypes](0)(x, y, ord)
case m: Mirror.Product[T] =>
case m: ProductMirrorOf[T] =>
eqlElems[m.ElemTypes](0)(x, y)
}
}
Expand Down Expand Up @@ -248,7 +273,7 @@ object Pickler {
case _: (alt *: alts1) =>
if (ord == n)
implicit match {
case m: Mirror.Product[`alt`] => pickleElems[m.ElemTypes](0)(buf, x)
case m: ProductMirrorOf[`alt`] => pickleElems[m.ElemTypes](0)(buf, x)
}
else pickleCases[alts1](n + 1)(buf, x, ord)
case _: Unit =>
Expand All @@ -266,7 +291,7 @@ object Pickler {
case _: Unit =>
}

inline def unpickleCase[T, Elems <: Tuple](buf: mutable.ListBuffer[Int], m: Mirror.Product[T]): T = {
inline def unpickleCase[T, Elems <: Tuple](buf: mutable.ListBuffer[Int], m: ProductMirrorOf[T]): T = {
inline val size = constValue[Tuple.Size[Elems]]
inline if (size == 0)
m.fromProduct(EmptyProduct)
Expand All @@ -282,30 +307,30 @@ object Pickler {
case _: (alt *: alts1) =>
if (ord == n)
implicit match {
case m: Mirror.Product[`alt` & T] =>
case m: ProductMirrorOf[`alt` & T] =>
unpickleCase[`alt` & T, m.ElemTypes](buf, m)
}
else unpickleCases[T, alts1](n + 1)(buf, ord)
case _: Unit =>
throw new IndexOutOfBoundsException(s"unexpected ordinal number: $ord")
}

inline def derived[T](implicit ev: Mirror[T]): Pickler[T] = new {
inline def derived[T](implicit ev: MirrorOf[T]): Pickler[T] = new {
def pickle(buf: mutable.ListBuffer[Int], x: T): Unit =
inline ev match {
case m: Mirror.Sum[T] =>
case m: SumMirrorOf[T] =>
val ord = m.ordinal(x)
buf += ord
pickleCases[m.ElemTypes](0)(buf, x, ord)
case m: Mirror.Product[T] =>
case m: ProductMirrorOf[T] =>
pickleElems[m.ElemTypes](0)(buf, x)
}
def unpickle(buf: mutable.ListBuffer[Int]): T =
inline ev match {
case m: Mirror.Sum[T] =>
case m: SumMirrorOf[T] =>
val ord = nextInt(buf)
unpickleCases[T, m.ElemTypes](0)(buf, ord)
case m: Mirror.Product[T] =>
case m: ProductMirrorOf[T] =>
unpickleCase[T, m.ElemTypes](buf, m)
}
}
Expand Down Expand Up @@ -341,31 +366,34 @@ object Show {
Nil
}

inline def showCase(x: Any, m: Mirror.Product[_]): String = {
inline def showCase(x: Any, m: ProductMirrorOf[_]): String = {
val label = constValue[m.CaseLabel]
showElems[m.ElemTypes, m.ElemLabels](0)(x).mkString(s"$label(", ", ", ")")
inline m match {
case m: SingletonMirror => label
case _ => showElems[m.ElemTypes, m.ElemLabels](0)(x).mkString(s"$label(", ", ", ")")
}
}

inline def showCases[Alts <: Tuple](n: Int)(x: Any, ord: Int): String =
inline erasedValue[Alts] match {
case _: (alt *: alts1) =>
if (ord == n)
implicit match {
case m: Mirror.Product[`alt`] =>
case m: ProductMirrorOf[`alt`] =>
showCase(x, m)
}
else showCases[alts1](n + 1)(x, ord)
case _: Unit =>
throw new MatchError(x)
}

inline def derived[T](implicit ev: Mirror[T]): Show[T] = new {
inline def derived[T](implicit ev: MirrorOf[T]): Show[T] = new {
def show(x: T): String =
inline ev match {
case m: Mirror.Sum[T] =>
case m: SumMirrorOf[T] =>
val ord = m.ordinal(x)
showCases[m.ElemTypes](0)(x, ord)
case m: Mirror.Product[T] =>
case m: ProductMirrorOf[T] =>
showCase(x, m)
}
}
Expand Down

0 comments on commit c77fdcc

Please sign in to comment.