Add a macro-based operator for mapping projections to case classes
The new `mapTo` operator requires less boilerplate than a traditional
case class mapping defined with `<>` and it can support case classes of
more than 22 elements by using an HList instead of a tuple on the
left-hand side. Mappings of up to 22 elements may use either one.
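As an illustration (not part of this diff), a minimal before-and-after sketch on a hypothetical Person table, assuming the H2 profile:

  import slick.driver.H2Driver.api._

  case class Person(id: Int, name: String)

  class People(tag: Tag) extends Table[Person](tag, "people") {
    def id = column[Int]("id", O.PrimaryKey)
    def name = column[String]("name")
    // Traditional mapping: explicit tupled/unapply plumbing
    // def * = (id, name) <> (Person.tupled, Person.unapply)
    // Macro-based mapping added by this commit
    def * = (id, name).mapTo[Person]
  }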
szeiger committed Oct 2, 2015
1 parent 4f22634 commit d980dbb
Showing 5 changed files with 90 additions and 15 deletions.
@@ -147,7 +147,7 @@ class AggregateTest extends AsyncTest[RelationalTestDB] {
  class T4(tag: Tag) extends Table[Pair](tag, "t4") {
    def a = column[Int]("a")
    def b = column[Option[Int]]("b")
-   def * = (a, b) <> (Pair.tupled,Pair.unapply)
+   def * = (a, b).mapTo[Pair]
  }
  val t4s = TableQuery[T4]
  db.run(t4s.schema.create >>
@@ -198,7 +198,7 @@ class AggregateTest extends AsyncTest[RelationalTestDB] {
    def col4 = column[Int]("COL4")
    def col5 = column[Int]("COL5")

-   def * = (col1, col2, col3, col4, col5) <> (Tab.tupled, Tab.unapply)
+   def * = (col1, col2, col3, col4, col5).mapTo[Tab]
  }
  val Tabs = TableQuery[Tabs]
@@ -60,7 +60,7 @@ class JdbcMapperTest extends AsyncTest[JdbcTestDB] {
  class T(tag: Tag) extends Table[Data](tag, "T") {
    def a = column[Int]("A")
    def b = column[Int]("B")
-   def * = (a, b) <> (Data.tupled, Data.unapply _)
+   def * = (a, b).mapTo[Data]
  }
  val ts = TableQuery[T]
@@ -77,8 +77,16 @@ class JdbcMapperTest extends AsyncTest[JdbcTestDB] {
  }

  def testWideMappedEntity = {
+   import slick.collection.heterogeneous._
+   import slick.collection.heterogeneous.syntax._
+
    case class Part(i1: Int, i2: Int, i3: Int, i4: Int, i5: Int, i6: Int)
    case class Whole(id: Int, p1: Part, p2: Part, p3: Part, p4: Part)
+   case class BigCase(id: Int,
+     p1i1: Int, p1i2: Int, p1i3: Int, p1i4: Int, p1i5: Int, p1i6: Int,
+     p2i1: Int, p2i2: Int, p2i3: Int, p2i4: Int, p2i5: Int, p2i6: Int,
+     p3i1: Int, p3i2: Int, p3i3: Int, p3i4: Int, p3i5: Int, p3i6: Int,
+     p4i1: Int, p4i2: Int, p4i3: Int, p4i4: Int, p4i5: Int, p4i6: Int)

    class T(tag: Tag) extends Table[Whole](tag, "t_wide") {
      def id = column[Int]("id", O.PrimaryKey)
@@ -106,19 +114,37 @@ class JdbcMapperTest extends AsyncTest[JdbcTestDB] {
      def p4i4 = column[Int]("p4i4")
      def p4i5 = column[Int]("p4i5")
      def p4i6 = column[Int]("p4i6")
-     def * = (
+     // Composable bidirectional mappings
+     def m1 = (
+       id,
+       (p1i1, p1i2, p1i3, p1i4, p1i5, p1i6).mapTo[Part],
+       (p2i1, p2i2, p2i3, p2i4, p2i5, p2i6) <> (Part.tupled, Part.unapply _),
+       (p3i1, p3i2, p3i3, p3i4, p3i5, p3i6).mapTo[Part],
+       (p4i1, p4i2, p4i3, p4i4, p4i5, p4i6).mapTo[Part]
+     ).mapTo[Whole]
+     // Manually composed mapping functions
+     def m2 = (
        id,
        (p1i1, p1i2, p1i3, p1i4, p1i5, p1i6),
        (p2i1, p2i2, p2i3, p2i4, p2i5, p2i6),
        (p3i1, p3i2, p3i3, p3i4, p3i5, p3i6),
        (p4i1, p4i2, p4i3, p4i4, p4i5, p4i6)
      ).shaped <> ({ case (id, p1, p2, p3, p4) =>
        // We could do this without .shaped but then we'd have to write a type annotation for the parameters
        Whole(id, Part.tupled.apply(p1), Part.tupled.apply(p2), Part.tupled.apply(p3), Part.tupled.apply(p4))
      }, { w: Whole =>
        def f(p: Part) = Part.unapply(p).get
        Some((w.id, f(w.p1), f(w.p2), f(w.p3), f(w.p4)))
      })
+     // HList-based wide case class mapping
+     def m3 = (
+       id ::
+       p1i1 :: p1i2 :: p1i3 :: p1i4 :: p1i5 :: p1i6 ::
+       p2i1 :: p2i2 :: p2i3 :: p2i4 :: p2i5 :: p2i6 ::
+       p3i1 :: p3i2 :: p3i3 :: p3i4 :: p3i5 :: p3i6 ::
+       p4i1 :: p4i2 :: p4i3 :: p4i4 :: p4i5 :: p4i6 :: HNil
+     ).mapTo[BigCase]
+     def * = m1
    }
    val ts = TableQuery[T]
@@ -132,7 +158,9 @@ class JdbcMapperTest extends AsyncTest[JdbcTestDB] {
      seq(
        ts.schema.create,
        ts += oData,
-       ts.result.head.map(_ shouldBe oData)
+       ts.result.head.map(_ shouldBe oData),
+       ts.map(_.m2).result.head.map(_ shouldBe oData),
+       ts.map(_.m3).result.head.map(_ shouldBe BigCase(0, 11, 12, 13, 14, 15, 16, 21, 22, 23, 24, 25, 26, 31, 32, 33, 34, 35, 36, 41, 42, 43, 44, 45, 46))
      )
    }
@@ -147,7 +175,7 @@ class JdbcMapperTest extends AsyncTest[JdbcTestDB] {
      def p3 = column[String]("p3")
      def p4 = column[Int]("p4")
      def part1 = (p1,p2) <> (Part1.tupled,Part1.unapply)
-     def part2 = (p3,p4) <> (Part2.tupled,Part2.unapply)
+     def part2 = (p3,p4).mapTo[Part2]
      def * = (part1, part2) <> (Whole.tupled,Whole.unapply)
    }
    val T = TableQuery[T]
@@ -169,14 +197,14 @@ class JdbcMapperTest extends AsyncTest[JdbcTestDB] {
    class ARow(tag: Tag) extends Table[A](tag, "t4_a") {
      def id = column[Int]("id", O.PrimaryKey, O.AutoInc)
      def data = column[Int]("data")
-     def * = (id, data) <> (A.tupled, A.unapply _)
+     def * = (id, data).mapTo[A]
    }
    val as = TableQuery[ARow]

    class BRow(tag: Tag) extends Table[B](tag, "t5_b") {
      def id = column[Int]("id", O.PrimaryKey, O.AutoInc)
      def data = column[String]("data")
-     def * = (id, Rep.Some(data)) <> (B.tupled, B.unapply _)
+     def * = (id, Rep.Some(data)).mapTo[B]
    }
    val bs = TableQuery[BRow]
@@ -297,11 +325,14 @@ class JdbcMapperTest extends AsyncTest[JdbcTestDB] {
    import slick.collection.heterogeneous._
    import slick.collection.heterogeneous.syntax._

+   case class Data(id: Int, b: Boolean, s: String)
+
    class B(tag: Tag) extends Table[Int :: Boolean :: String :: HNil](tag, "hlist_b") {
      def id = column[Int]("id", O.PrimaryKey)
      def b = column[Boolean]("b")
      def s = column[String]("s")
      def * = id :: b :: s :: HNil
+     def mapped = *.mapTo[Data]
    }
    val bs = TableQuery[B]
@@ -320,9 +351,10 @@ class JdbcMapperTest extends AsyncTest[JdbcTestDB] {
        bs.schema.create,
        bs += (1 :: true :: "a" :: HNil),
        bs += (2 :: false :: "c" :: HNil),
-       bs += (3 :: false :: "b" :: HNil),
+       bs.map(_.mapped) += Data(3, false, "b"),
        q1.result.map(_ shouldBe Vector(3 :: "bb" :: (42 :: HNil) :: HNil, 2 :: "cc" :: (42 :: HNil) :: HNil)),
-       q1.result.map(_ shouldBe Vector(3 :: "bb" :: (42 :: HNil) :: HNil, 2 :: "cc" :: (42 :: HNil) :: HNil))
+       q2.result.map(_ shouldBe Vector(3 :: "bb" :: (42 :: HNil) :: HNil, 2 :: "cc" :: (42 :: HNil) :: HNil)),
+       bs.map(_.mapped).result.map(_.toSet shouldBe Set(Data(1, true, "a"), Data(2, false, "c"), Data(3, false, "b")))
      )
    }
@@ -374,7 +406,7 @@ class JdbcMapperTest extends AsyncTest[JdbcTestDB] {
    class T(tag: Tag) extends Table[Data](tag, "T_fastpath") {
      def a = column[Int]("A")
      def b = column[Int]("B")
-     def * = (a, b) <> (Data.tupled, Data.unapply _) fastPath(new FastPath(_) {
+     def * = (a, b).mapTo[Data].fastPath(new FastPath(_) {
        val (a, b) = (next[Int], next[Int])
        override def read(r: Reader) = Data(a.read(r), b.read(r))
      })
@@ -218,7 +218,7 @@ class NestingTest extends AsyncTest[RelationalTestDB] {
    def id = column[Long]("id", O.PrimaryKey, O.AutoInc)
    def name = column[Option[String]]("name")
    def popularOptions = column[Option[String]]("popularOptions")
-   def * = (name.getOrElse(""), popularOptions.getOrElse(""), id) <> (Chord.tupled, Chord.unapply)
+   def * = (name.getOrElse(""), popularOptions.getOrElse(""), id).mapTo[Chord]
  }
  val chords = TableQuery[Chords]
  val allChords = Set(Chord("maj7", "9 #11"), Chord("m7", "9 11"), Chord("7", "9 13"), Chord("m7b5", "11"), Chord("aug7", "9"), Chord("dim7", ""))
@@ -162,7 +162,7 @@ class RelationalMiscTest extends AsyncTest[RelationalTestDB] {
    class A(tag: Tag) extends Table[Customer](tag, "INIT_A") {
      def id = column[Id]("ID", O.PrimaryKey, O.AutoInc)(Tables.idMapper)
      import Tables.idMapper
-     def * = id.<>(Customer.apply, Customer.unapply)
+     def * = id.mapTo[Customer]
    }
    Tables.as.schema
45 changes: 44 additions & 1 deletion slick/src/main/scala/slick/lifted/Shape.scala
@@ -1,8 +1,10 @@
package slick.lifted

import scala.language.{existentials, implicitConversions, higherKinds}
+import scala.language.experimental.macros
import scala.annotation.implicitNotFound
import scala.annotation.unchecked.uncheckedVariance
+import scala.reflect.macros.blackbox.Context
import slick.SlickException
import slick.util.{ConstArray, ProductWrapper, TupleSupport}
import slick.ast._
@@ -269,12 +271,53 @@ case class ShapedValue[T, U](value: T, shape: Shape[_ <: FlatShapeLevel, T, U, _
  def toNode = shape.toNode(value)
  def packedValue[R](implicit ev: Shape[_ <: FlatShapeLevel, T, _, R]): ShapedValue[R, U] = ShapedValue(shape.pack(value).asInstanceOf[R], shape.packedShape.asInstanceOf[Shape[FlatShapeLevel, R, U, _]])
  def zip[T2, U2](s2: ShapedValue[T2, U2]) = new ShapedValue[(T, T2), (U, U2)]((value, s2.value), Shape.tuple2Shape(shape, s2.shape))
- @inline def <>[R : ClassTag](f: (U => R), g: (R => Option[U])) = new MappedProjection[R, U](shape.toNode(value), MappedScalaType.Mapper(g.andThen(_.get).asInstanceOf[Any => Any], f.asInstanceOf[Any => Any], None), implicitly[ClassTag[R]])
+ def <>[R : ClassTag](f: (U => R), g: (R => Option[U])) = new MappedProjection[R, U](shape.toNode(value), MappedScalaType.Mapper(g.andThen(_.get).asInstanceOf[Any => Any], f.asInstanceOf[Any => Any], None), implicitly[ClassTag[R]])
  @inline def shaped: ShapedValue[T, U] = this
+
+ def mapTo[R <: Product with Serializable](implicit rCT: ClassTag[R]): MappedProjection[R, U] = macro ShapedValue.mapToImpl[R, U]
}

object ShapedValue {
  @inline implicit def shapedValueShape[T, U, Level <: ShapeLevel] = RepShape[Level, ShapedValue[T, U], U]
+
+ def mapToImpl[R <: Product with Serializable, U](c: Context { type PrefixType = ShapedValue[_, U] })(rCT: c.Expr[ClassTag[R]])(implicit rTag: c.WeakTypeTag[R], uTag: c.WeakTypeTag[U]): c.Tree = {
+   import c.universe._
+   val rSym = symbolOf[R]
+   if(!rSym.isClass || !rSym.asClass.isCaseClass)
+     c.abort(c.enclosingPosition, s"${rSym.fullName} must be a case class")
+   val rModule = rSym.companion match {
+     case NoSymbol => q"${rSym.name.toTermName}" // This can happen for case classes defined inside of methods
+     case s => q"$s"
+   }
+   val caseFields = rTag.tpe.decls.collect {
+     case s: TermSymbol if s.isVal && s.isCaseAccessor => (TermName(s.name.toString.trim), s.typeSignature)
+   }.toIndexedSeq
+   val (f, g) = if(uTag.tpe <:< c.typeOf[slick.collection.heterogeneous.HList]) { // Map from HList
+     val rTypeAsHList = caseFields.foldRight[Tree](tq"_root_.slick.collection.heterogeneous.HNil.type") {
+       case ((_, t), z) => tq"_root_.slick.collection.heterogeneous.HCons[$t, $z]"
+     }
+     val matchNames = caseFields.map(_ => TermName(c.freshName()))
+     val pat = matchNames.foldRight[Tree](pq"_root_.slick.collection.heterogeneous.HNil") {
+       case (n, z) => pq"_root_.slick.collection.heterogeneous.HCons($n, $z)"
+     }
+     val cons = caseFields.foldRight[Tree](q"_root_.slick.collection.heterogeneous.HNil") {
+       case ((n, _), z) => q"v.$n :: $z"
+     }
+     (q"({ case $pat => new $rTag(..$matchNames) } : ($rTypeAsHList => $rTag)): ($uTag => $rTag)",
+      q"{ case v => $cons }: ($rTag => $uTag)")
+   } else if(caseFields.length == 1) { // Map from single value
+     (q"($rModule.apply _) : ($uTag => $rTag)",
+      q"(($rModule.unapply _) : $rTag => Option[$uTag]).andThen(_.get)")
+   } else { // Map from tuple
+     (q"($rModule.tupled) : ($uTag => $rTag)",
+      q"(($rModule.unapply _) : $rTag => Option[$uTag]).andThen(_.get)")
+   }
+   q"""val ff = $f // Resolving f first creates more useful type errors
+       new _root_.slick.lifted.MappedProjection[$rTag, $uTag](${c.prefix}.toNode,
+         _root_.slick.ast.MappedScalaType.Mapper($g.asInstanceOf[Any => Any], ff.asInstanceOf[Any => Any], _root_.scala.None),
+         $rCT
+       )"""
+ }
}

/** A limited version of ShapedValue which can be constructed for every type

2 comments on commit d980dbb

@ikhoon commented on d980dbb Aug 7, 2016


Wow, awesome!! 👍

@phderome

Does this commit address some of the complaints about slow compile times when using Shapeless as a workaround for the 22-parameter limit, as in this discussion: milessabin/shapeless#619?
Said differently, is compile time acceptable with case classes of about 100 columns using this build of Slick?
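For reference, the form this commit adds for wide rows is the HList projection with mapTo. A minimal sketch, assuming the H2 profile and a hypothetical Accounts table; the projection stays the same past 22 columns because no tuple is involved, though this says nothing about compile times at around 100 columns:

  import slick.driver.H2Driver.api._
  import slick.collection.heterogeneous._
  import slick.collection.heterogeneous.syntax._

  // Hypothetical row type; the HNil-terminated projection below keeps the same shape
  // past 22 fields because mapTo maps the HList directly to the case class constructor.
  case class Account(id: Int, name: String, balance: Int, active: Boolean)

  class Accounts(tag: Tag) extends Table[Account](tag, "accounts") {
    def id = column[Int]("id", O.PrimaryKey)
    def name = column[String]("name")
    def balance = column[Int]("balance")
    def active = column[Boolean]("active")
    def * = (id :: name :: balance :: active :: HNil).mapTo[Account]
  }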
