Skip to content

Commit

Permalink
Early transformation of monadic joins into applicative joins.
Browse files Browse the repository at this point in the history
- A new compiler phase `rewriteJoins`, which runs directly after
  `flattenProjections`, converts all monadic joins into applicative
  joins. In its current state it is good enough to pass all tests.
  The implementation is much simpler than the old one in
  `fuseComprehensions` (the AST is still purely functional at this
  point) and should be easy to extend to cover more cases.

- A new compiler phase `verifySymbols` runs after `rewriteJoins` to
  check that all joins have been properly rewritten and that all
  references are reachable. When a join could not be transformed, this
  will fail with a useful error message instead of producing invalid
  SQL code.

- `createResultSetMapping` is now run at the end of the standard phases.
  It uses the original result type stored by the new `removeMappedTypes`
  phase, thus keeping the AST free of the client-side parts until the
  point where they actually match the server side.

- `fuseComprehensions` requires a small change to translate aggregations
  arising from `rewriteJoins`. This is only to make it work for now. The
  whole phase needs to be rewritten from scratch.

- Better code generator for explicit join syntax. The default in
  JdbcStatementBuilderComponent implements the standard SQL syntax.
  MySQL, SQLite and Hsqldb require special handling (in particular
  Hsqldb, which does not support arbitrary nesting on the RHS of a
  join).
  • Loading branch information
szeiger committed Jun 22, 2015
1 parent f53f06e commit 2b14139
Show file tree
Hide file tree
Showing 24 changed files with 525 additions and 118 deletions.
4 changes: 3 additions & 1 deletion common-test-resources/logback.xml
Expand Up @@ -16,13 +16,15 @@
<logger name="slick.compiler.AssignUniqueSymbols" level="${log.qcomp.assignUniqueSymbols:-inherited}" /> <logger name="slick.compiler.AssignUniqueSymbols" level="${log.qcomp.assignUniqueSymbols:-inherited}" />
<logger name="slick.compiler.InferTypes" level="${log.qcomp.inferTypes:-inherited}" /> <logger name="slick.compiler.InferTypes" level="${log.qcomp.inferTypes:-inherited}" />
<logger name="slick.compiler.ExpandTables" level="${log.qcomp.expandTables:-inherited}" /> <logger name="slick.compiler.ExpandTables" level="${log.qcomp.expandTables:-inherited}" />
<logger name="slick.compiler.CreateResultSetMapping" level="${log.qcomp.createResultSetMapping:-inherited}" />
<logger name="slick.compiler.EmulateOuterJoins" level="${log.qcomp.emulateOuterJoins:-inherited}" /> <logger name="slick.compiler.EmulateOuterJoins" level="${log.qcomp.emulateOuterJoins:-inherited}" />
<logger name="slick.compiler.ForceOuterBinds" level="${log.qcomp.forceOuterBinds:-inherited}" /> <logger name="slick.compiler.ForceOuterBinds" level="${log.qcomp.forceOuterBinds:-inherited}" />
<logger name="slick.compiler.RemoveMappedTypes" level="${log.qcomp.removeMappedTypes:-inherited}" />
<logger name="slick.compiler.CreateResultSetMapping" level="${log.qcomp.createResultSetMapping:-inherited}" />
<logger name="slick.compiler.ExpandSums" level="${log.qcomp.expandSums:-inherited}" /> <logger name="slick.compiler.ExpandSums" level="${log.qcomp.expandSums:-inherited}" />
<logger name="slick.compiler.ExpandRecords" level="${log.qcomp.expandRecords:-inherited}" /> <logger name="slick.compiler.ExpandRecords" level="${log.qcomp.expandRecords:-inherited}" />
<logger name="slick.compiler.ExpandConditionals" level="${log.qcomp.expandConditionals:-inherited}" /> <logger name="slick.compiler.ExpandConditionals" level="${log.qcomp.expandConditionals:-inherited}" />
<logger name="slick.compiler.FlattenProjections" level="${log.qcomp.flattenProjections:-inherited}" /> <logger name="slick.compiler.FlattenProjections" level="${log.qcomp.flattenProjections:-inherited}" />
<logger name="slick.compiler.RewriteJoins" level="${log.qcomp.rewriteJoins:-inherited}" />
<logger name="slick.compiler.RelabelUnions" level="${log.qcomp.relabelUnions:-inherited}" /> <logger name="slick.compiler.RelabelUnions" level="${log.qcomp.relabelUnions:-inherited}" />
<logger name="slick.compiler.PruneFields" level="${log.qcomp.pruneFields:-inherited}" /> <logger name="slick.compiler.PruneFields" level="${log.qcomp.pruneFields:-inherited}" />
<logger name="slick.compiler.ResolveZipJoins" level="${log.qcomp.resolveZipJoins:-inherited}" /> <logger name="slick.compiler.ResolveZipJoins" level="${log.qcomp.resolveZipJoins:-inherited}" />
Expand Down
4 changes: 2 additions & 2 deletions project/Build.scala
Expand Up @@ -13,8 +13,8 @@ import de.johoop.testngplugin.TestNGPlugin._


object SlickBuild extends Build { object SlickBuild extends Build {


val slickVersion = "3.0.0-RC3" val slickVersion = "3.1.0-SNAPSHOT"
val binaryCompatSlickVersion = "3.0.0" // Slick base version for binary compatibility checks val binaryCompatSlickVersion = "3.1.0" // Slick base version for binary compatibility checks
val scalaVersions = Seq("2.10.5", "2.11.6") val scalaVersions = Seq("2.10.5", "2.11.6")


/** Dependencies for reuse in different parts of the build */ /** Dependencies for reuse in different parts of the build */
Expand Down
4 changes: 3 additions & 1 deletion slick-testkit/src/doctest/resources/logback.xml
Expand Up @@ -16,13 +16,15 @@
<logger name="slick.compiler.AssignUniqueSymbols" level="${log.qcomp.assignUniqueSymbols:-inherited}" /> <logger name="slick.compiler.AssignUniqueSymbols" level="${log.qcomp.assignUniqueSymbols:-inherited}" />
<logger name="slick.compiler.InferTypes" level="${log.qcomp.inferTypes:-inherited}" /> <logger name="slick.compiler.InferTypes" level="${log.qcomp.inferTypes:-inherited}" />
<logger name="slick.compiler.ExpandTables" level="${log.qcomp.expandTables:-inherited}" /> <logger name="slick.compiler.ExpandTables" level="${log.qcomp.expandTables:-inherited}" />
<logger name="slick.compiler.CreateResultSetMapping" level="${log.qcomp.createResultSetMapping:-inherited}" />
<logger name="slick.compiler.EmulateOuterJoins" level="${log.qcomp.emulateOuterJoins:-inherited}" /> <logger name="slick.compiler.EmulateOuterJoins" level="${log.qcomp.emulateOuterJoins:-inherited}" />
<logger name="slick.compiler.ForceOuterBinds" level="${log.qcomp.forceOuterBinds:-inherited}" /> <logger name="slick.compiler.ForceOuterBinds" level="${log.qcomp.forceOuterBinds:-inherited}" />
<logger name="slick.compiler.RemoveMappedTypes" level="${log.qcomp.removeMappedTypes:-inherited}" />
<logger name="slick.compiler.CreateResultSetMapping" level="${log.qcomp.createResultSetMapping:-inherited}" />
<logger name="slick.compiler.ExpandSums" level="${log.qcomp.expandSums:-inherited}" /> <logger name="slick.compiler.ExpandSums" level="${log.qcomp.expandSums:-inherited}" />
<logger name="slick.compiler.ExpandRecords" level="${log.qcomp.expandRecords:-inherited}" /> <logger name="slick.compiler.ExpandRecords" level="${log.qcomp.expandRecords:-inherited}" />
<logger name="slick.compiler.ExpandConditionals" level="${log.qcomp.expandConditionals:-inherited}" /> <logger name="slick.compiler.ExpandConditionals" level="${log.qcomp.expandConditionals:-inherited}" />
<logger name="slick.compiler.FlattenProjections" level="${log.qcomp.flattenProjections:-inherited}" /> <logger name="slick.compiler.FlattenProjections" level="${log.qcomp.flattenProjections:-inherited}" />
<logger name="slick.compiler.RewriteJoins" level="${log.qcomp.rewriteJoins:-inherited}" />
<logger name="slick.compiler.RelabelUnions" level="${log.qcomp.relabelUnions:-inherited}" /> <logger name="slick.compiler.RelabelUnions" level="${log.qcomp.relabelUnions:-inherited}" />
<logger name="slick.compiler.PruneFields" level="${log.qcomp.pruneFields:-inherited}" /> <logger name="slick.compiler.PruneFields" level="${log.qcomp.pruneFields:-inherited}" />
<logger name="slick.compiler.ResolveZipJoins" level="${log.qcomp.resolveZipJoins:-inherited}" /> <logger name="slick.compiler.ResolveZipJoins" level="${log.qcomp.resolveZipJoins:-inherited}" />
Expand Down
Expand Up @@ -175,12 +175,12 @@ class AggregateTest extends AsyncTest[RelationalTestDB] {
def * = id def * = id
} }
val as = TableQuery[A] val as = TableQuery[A]
for { val q1 = as.groupBy(_.id).map { case (_, q) => (q.map(_.id).min, q.length) }
_ <- as.schema.create DBIO.seq(
_ <- as += 1 as.schema.create,
q1 = as.groupBy(_.id).map { case (_, q) => (q.map(_.id).min, q.length) } as += 1,
_ <- q1.result q1.result
} yield () )
} }


def testGroup3 = { def testGroup3 = {
Expand Down
Expand Up @@ -69,27 +69,24 @@ class CountTest extends AsyncTest[RelationalTestDB] {
def * = (aId, data) def * = (aId, data)
} }
lazy val bs = TableQuery[B] lazy val bs = TableQuery[B]
for { DBIO.seq(
_ <- (as.schema ++ bs.schema).create (as.schema ++ bs.schema).create,
_ <- as ++= Seq(1L, 2L) as ++= Seq(1L, 2L),
_ <- bs ++= Seq((1L, "1a"), (1L, "1b"), (2L, "2")) bs ++= Seq((1L, "1a"), (1L, "1b"), (2L, "2")),
qDirectLength = for { (for {
a <- as if a.id === 1L a <- as if a.id === 1L
} yield (a, (for { } yield (a, (for {
b <- bs if b.aId === a.id b <- bs if b.aId === a.id
} yield b).length) } yield b).length)).result.named("directLength").map(_ shouldBe Seq((1L, 2))),
_ <- qDirectLength.result.map(_ shouldBe Seq((1L, 2))) (for {
qJoinLength = for {
a <- as if a.id === 1L a <- as if a.id === 1L
l <- Query((for { l <- Query((for {
b <- bs if b.aId === a.id b <- bs if b.aId === a.id
} yield b).length) } yield b).length)
} yield (a, l) } yield (a, l)).result.named("joinLength").map(_ shouldBe Seq((1L, 2))),
_ <- qJoinLength.result.map(_ shouldBe Seq((1L, 2))) (for {
qOuterJoinLength = (for {
(a, b) <- as joinLeft bs on (_.id === _.aId) (a, b) <- as joinLeft bs on (_.id === _.aId)
} yield (a.id, b.map(_.data))).length } yield (a.id, b.map(_.data))).length.result.named("outerJoinLength").map(_ shouldBe 3)
_ <- qOuterJoinLength.result.map(_ shouldBe 3) )
} yield ()
} }
} }
Expand Up @@ -304,16 +304,16 @@ class JoinTest extends AsyncTest[RelationalTestDB] {
} }
lazy val cs = TableQuery[C] lazy val cs = TableQuery[C]


val q1 = for { def q1 = for {
(a, b) <- as joinLeft bs on (_.id === _.foreignId) (a, b) <- as joinLeft bs on (_.id === _.foreignId)
} yield (a, b) } yield (a, b)


val q2 = for { def q2 = for {
(a, b) <- q1 (a, b) <- q1
c <- cs if c.foreignId === a.id c <- cs if c.foreignId === a.id
} yield (a, c) } yield (a, c)


val q3 = for { def q3 = for {
(a, b) <- as joinLeft bs on (_.id === _.foreignId) (a, b) <- as joinLeft bs on (_.id === _.foreignId)
c <- cs if c.foreignId === a.id c <- cs if c.foreignId === a.id
} yield (a, c) } yield (a, c)
Expand All @@ -323,9 +323,9 @@ class JoinTest extends AsyncTest[RelationalTestDB] {
as ++= Seq(1,2,3), as ++= Seq(1,2,3),
bs ++= Seq(1,2,4,5), bs ++= Seq(1,2,4,5),
cs ++= Seq(1,2,4,6), cs ++= Seq(1,2,4,6),
q1.result.map(_.toSet shouldBe Set((1, Some(1)), (2, Some(2)), (3, None))), q1.result.named("q1").map(_.toSet shouldBe Set((1, Some(1)), (2, Some(2)), (3, None))),
q2.result.map(_.toSet shouldBe Set((1,1), (2,2))), q2.result.named("q2").map(_.toSet shouldBe Set((1,1), (2,2))),
q3.result.map(_.toSet shouldBe Set((1,1), (2,2))) q3.result.named("q3").map(_.toSet shouldBe Set((1,1), (2,2)))
) )
} }
} }
Expand Up @@ -104,13 +104,13 @@ class NestingTest extends AsyncTest[RelationalTestDB] {
val q5t: Query[(Rep[Option[Int]], Rep[Option[String]]), _, Seq] = q5 val q5t: Query[(Rep[Option[Int]], Rep[Option[String]]), _, Seq] = q5


val t1 = seq( val t1 = seq(
q1.result.map(_ shouldBe r.map(t => Some(t))), q1.result.named("q1").map(_ shouldBe r.map(t => Some(t))),
q1a2.result.map(_ shouldBe r.map(t => Some(Some(t)))), q1a2.result.named("q1a2").map(_ shouldBe r.map(t => Some(Some(t)))),
q2.result.map(_ shouldBe r.map(t => Some(t._1))), q2.result.named("q2").map(_ shouldBe r.map(t => Some(t._1))),
q2a2.result.map(_ shouldBe r.map(t => Some(Some(t._1)))), q2a2.result.named("q2a2").map(_ shouldBe r.map(t => Some(Some(t._1)))),
q3.result.map(_ shouldBe r.map(t => t._3)), q3.result.named("q3").map(_ shouldBe r.map(t => t._3)),
q4.result.map(_ shouldBe r.map(t => Some(t._3))), q4.result.named("q4").map(_ shouldBe r.map(t => Some(t._3))),
q5.result.map(_ shouldBe r.map(t => (t._3, Some(t._2)))) q5.result.named("q5").map(_ shouldBe r.map(t => (t._3, Some(t._2))))
) )


// Get plain values out // Get plain values out
Expand All @@ -124,10 +124,10 @@ class NestingTest extends AsyncTest[RelationalTestDB] {
val q4bt: Query[Rep[Option[Int]], _, Seq] = q4b val q4bt: Query[Rep[Option[Int]], _, Seq] = q4b


val t2 = seq( val t2 = seq(
q1b.result.map(_ shouldBe r.map(t => Some(t)).map(_.getOrElse((0, "", None: Option[String])))), q1b.result.named("q1b").map(_ shouldBe r.map(t => Some(t)).map(_.getOrElse((0, "", None: Option[String])))),
q2b.result.map(_ shouldBe r.map(t => Some(t._1)).map(_.get)), q2b.result.named("q2b").map(_ shouldBe r.map(t => Some(t._1)).map(_.get)),
q3b.result.map(_ shouldBe r.map(t => t._3).filter(_.isDefined).map(_.get)), q3b.result.named("q3b").map(_ shouldBe r.map(t => t._3).filter(_.isDefined).map(_.get)),
q4b.result.map(_ shouldBe r.map(t => Some(t._3)).map(_.getOrElse(None: Option[String]))) q4b.result.named("q4b").map(_ shouldBe r.map(t => Some(t._3)).map(_.getOrElse(None: Option[String])))
) )


// Unpack result types // Unpack result types
Expand All @@ -142,8 +142,8 @@ class NestingTest extends AsyncTest[RelationalTestDB] {
val q3c = q3.map(so => so + 10) val q3c = q3.map(so => so + 10)


val t3 = seq( val t3 = seq(
q2c.result.map(_ shouldBe r.map(t => Some(t._1)).map(_.map(_ + 42))), q2c.result.named("q2c").map(_ shouldBe r.map(t => Some(t._1)).map(_.map(_ + 42))),
q3c.result.map(_ shouldBe r.map(t => t._3).map(_.map(_ + 10))) q3c.result.named("q3c").map(_ shouldBe r.map(t => t._3).map(_.map(_ + 10)))
) )


// Use Option.map // Use Option.map
Expand All @@ -163,11 +163,11 @@ class NestingTest extends AsyncTest[RelationalTestDB] {
val q4dt: Query[Rep[Option[Int]], _, Seq] = q4d val q4dt: Query[Rep[Option[Int]], _, Seq] = q4d


val t4 = seq( val t4 = seq(
q1d.result.map(_ shouldBe r.map(t => Some(t)).map(_.map(_._1))), q1d.result.named("q1d").map(_ shouldBe r.map(t => Some(t)).map(_.map(_._1))),
q1d2.result.map(_ shouldBe r.map(t => Some(t)).map(_.map(x => (x._1, x._2, x._3)))), q1d2.result.named("q1d2").map(_ shouldBe r.map(t => Some(t)).map(_.map(x => (x._1, x._2, x._3)))),
q2d.result.map(_ shouldBe r.map(t => Some(t._1)).map(_.map(_ + 1))), q2d.result.named("q2d").map(_ shouldBe r.map(t => Some(t._1)).map(_.map(_ + 1))),
q3d.result.map(_ shouldBe r.map(t => t._3).map(_.map(s => (s, s, 1)))), q3d.result.named("q3d").map(_ shouldBe r.map(t => t._3).map(_.map(s => (s, s, 1)))),
q4d.result.map(_ shouldBe r.map(t => Some(t._3)).map(_.filter(_.isDefined).map(_.get))) q4d.result.named("q4d").map(_ shouldBe r.map(t => Some(t._3)).map(_.filter(_.isDefined).map(_.get)))
) )


// Use Option.flatMap // Use Option.flatMap
Expand All @@ -180,10 +180,10 @@ class NestingTest extends AsyncTest[RelationalTestDB] {
val q2et: Query[Rep[Option[Int]], _, Seq] = q2e val q2et: Query[Rep[Option[Int]], _, Seq] = q2e


val t5 = seq( val t5 = seq(
q1e1.result.map(_ shouldBe r.map(t => Some(t)).map { to => to.flatMap { t => Some(t._2) }}), q1e1.result.named("q1e1").map(_ shouldBe r.map(t => Some(t)).map { to => to.flatMap { t => Some(t._2) }}),
q1e2.result.map(_ shouldBe r.map(t => Some(t)).map { to => to.flatMap { t => t._3 }}), q1e2.result.named("q1e2").map(_ shouldBe r.map(t => Some(t)).map { to => to.flatMap { t => t._3 }}),
q1e3.result.map(_ shouldBe r.map(t => Some(t)).map(to => Some(to)).map(_.flatMap(identity))), q1e3.result.named("q1e3").map(_ shouldBe r.map(t => Some(t)).map(to => Some(to)).map(_.flatMap(identity))),
q2e.result.map(_ shouldBe r.map(t => Some(t._1)).map { io => io.flatMap { i => Some(i) }}) q2e.result.named("q2e").map(_ shouldBe r.map(t => Some(t._1)).map { io => io.flatMap { i => Some(i) }})
) )


// Use Option.flatten // Use Option.flatten
Expand All @@ -201,12 +201,12 @@ class NestingTest extends AsyncTest[RelationalTestDB] {
val q2f3t: Query[Rep[Option[Int]], _, Seq] = q2f3 val q2f3t: Query[Rep[Option[Int]], _, Seq] = q2f3


val t6 = seq( val t6 = seq(
q1f1.result.map(_ shouldBe Vector(Some(Some((1,"1",Some(1)))), Some(Some((2,"2",Some(2)))), Some(Some((3,"3",None))))), q1f1.result.named("q1f1").map(_ shouldBe Vector(Some(Some((1,"1",Some(1)))), Some(Some((2,"2",Some(2)))), Some(Some((3,"3",None))))),
q1f2.result.map(_ shouldBe r.map(t => Some(t)).map { to => Some(to).flatten }), q1f2.result.named("q1f2").map(_ shouldBe r.map(t => Some(t)).map { to => Some(to).flatten }),
q1f3.result.map(_ shouldBe r.map(t => Some(t)).map { to => Some(to) }.map(_.flatten)), q1f3.result.named("q1f3").map(_ shouldBe r.map(t => Some(t)).map { to => Some(to) }.map(_.flatten)),
q2f1.result.map(_ shouldBe r.map(t => Some(t._1)).map { io => Some(io) }), q2f1.result.named("q2f1").map(_ shouldBe r.map(t => Some(t._1)).map { io => Some(io) }),
q2f2.result.map(_ shouldBe r.map(t => Some(t._1)).map { io => Some(io).flatten }), q2f2.result.named("q2f2").map(_ shouldBe r.map(t => Some(t._1)).map { io => Some(io).flatten }),
q2f3.result.map(_ shouldBe r.map(t => Some(t._1)).map { io => Some(io) }.map(_.flatten)) q2f3.result.named("q2f3").map(_ shouldBe r.map(t => Some(t._1)).map { io => Some(io) }.map(_.flatten))
) )


setup >> t1 >> t2 >> t3 >> t4 >> t5 >> t6 setup >> t1 >> t2 >> t3 >> t4 >> t5 >> t6
Expand Down
Expand Up @@ -112,7 +112,7 @@ class NewQuerySemanticsTest extends AsyncTest[RelationalTestDB] {
(c2, s2) <- q1b_0 (c2, s2) <- q1b_0
} yield (c.name, s.city, c2.name) } yield (c.name, s.city, c2.name)


val a2 = seq( def a2 = seq(
q0.result.named("Plain table").map(_.toSet).map { r0 => q0.result.named("Plain table").map(_.toSet).map { r0 =>
r0 shouldBe Set( r0 shouldBe Set(
("Colombian", 101, 799, 1, 0), ("Colombian", 101, 799, 1, 0),
Expand Down Expand Up @@ -181,7 +181,7 @@ class NewQuerySemanticsTest extends AsyncTest[RelationalTestDB] {
d <- q4b_0 d <- q4b_0
} yield (c,d) } yield (c,d)


val a3 = seq( def a3 = seq(
q2.result.named("More elaborate query").map(_.toSet).map { r2 => q2.result.named("More elaborate query").map(_.toSet).map { r2 =>
r2 shouldBe Set( r2 shouldBe Set(
("Colombian","Acme, Inc."), ("Colombian","Acme, Inc."),
Expand Down Expand Up @@ -229,7 +229,7 @@ class NewQuerySemanticsTest extends AsyncTest[RelationalTestDB] {
// Unused outer query result, unbound TableQuery // Unused outer query result, unbound TableQuery
val q6 = coffees.flatMap(c => suppliers) val q6 = coffees.flatMap(c => suppliers)


val a4 = seq( def a4 = seq(
q5.result.map(_.toSet).map { r5 => q5.result.map(_.toSet).map { r5 =>
r5 shouldBe Set( r5 shouldBe Set(
(("Colombian",101,799,1,0),("Colombian",101,799,1,0)), (("Colombian",101,799,1,0),("Colombian",101,799,1,0)),
Expand Down Expand Up @@ -268,7 +268,7 @@ class NewQuerySemanticsTest extends AsyncTest[RelationalTestDB] {
c <- coffees.filter(_.price < 800).map((_, 1)) c <- coffees.filter(_.price < 800).map((_, 1))
} yield (c._1.name, c._1.supID, c._2) } yield (c._1.name, c._1.supID, c._2)


val a5 = seq( def a5 = seq(
q7a.result.named("Simple union").map(_.toSet).map { r7a => q7a.result.named("Simple union").map(_.toSet).map { r7a =>
r7a shouldBe Set( r7a shouldBe Set(
("Colombian",101,0), ("Colombian",101,0),
Expand Down Expand Up @@ -306,7 +306,7 @@ class NewQuerySemanticsTest extends AsyncTest[RelationalTestDB] {
t <- coffees.sortBy(_.sales).take(1) joinLeft coffees.sortBy(_.sales).take(2) on (_.name === _.name) joinLeft coffees.sortBy(_.sales).take(4) on (_._1.supID === _.supID) t <- coffees.sortBy(_.sales).take(1) joinLeft coffees.sortBy(_.sales).take(2) on (_.name === _.name) joinLeft coffees.sortBy(_.sales).take(4) on (_._1.supID === _.supID)
} yield (t._1, t._2) } yield (t._1, t._2)


val a6 = seq( def a6 = seq(
q7b.result.named("Union with filter on the outside").map(_.toSet).map { r7b => q7b.result.named("Union with filter on the outside").map(_.toSet).map { r7b =>
r7b shouldBe Set( r7b shouldBe Set(
("French_Roast",49,1), ("French_Roast",49,1),
Expand Down
Expand Up @@ -3,8 +3,8 @@ package slick.benchmark
import slick.ast._ import slick.ast._
import slick.jdbc._ import slick.jdbc._
import slick.relational._ import slick.relational._
import slick.util.TreePrinter
import com.typesafe.slick.testkit.util.DelegateResultSet import com.typesafe.slick.testkit.util.DelegateResultSet
import slick.util.TreePrinter


@deprecated("Using deprecated .simple API", "3.0") @deprecated("Using deprecated .simple API", "3.0")
object UnboxedBenchmark extends App { object UnboxedBenchmark extends App {
Expand Down
7 changes: 6 additions & 1 deletion slick/src/main/scala/slick/ast/Node.scala
Expand Up @@ -84,6 +84,8 @@ trait Node extends Dumpable {
else nodeTyped(tpe) else nodeTyped(tpe)
} }


final def :@ (tpe: Type): Self = nodeTypedOrCopy(tpe)

def nodeBuildTypedNode[T >: this.type <: Node](newNode: T, newType: Type): T = def nodeBuildTypedNode[T >: this.type <: Node](newNode: T, newType: Type): T =
if(newNode ne this) newNode.nodeTyped(newType) if(newNode ne this) newNode.nodeTyped(newType)
else if(newType == _nodeType) this else if(newType == _nodeType) this
Expand Down Expand Up @@ -514,7 +516,8 @@ final case class Ref(sym: Symbol) extends NullaryNode {
def nodeRebuild = copy() def nodeRebuild = copy()
} }


/** A constructor/extractor for nested Selects starting at a Ref. */ /** A constructor/extractor for nested Selects starting at a Ref so that, for example,
* `c :: b :: a :: Nil` corresponds to path `a.b.c`. */
object Path { object Path {
def apply(l: List[Symbol]): Node = l match { def apply(l: List[Symbol]): Node = l match {
case s :: Nil => Ref(s) case s :: Nil => Ref(s)
Expand All @@ -532,6 +535,8 @@ object Path {
} }
} }


/** A constructor/extractor for nested Selects starting at a Ref so that, for example,
* `a :: b :: c :: Nil` corresponds to path `a.b.c`. */
object FwdPath { object FwdPath {
def apply(ch: List[Symbol]) = Path(ch.reverse) def apply(ch: List[Symbol]) = Path(ch.reverse)
def unapply(n: Node): Option[List[Symbol]] = Path.unapply(n).map(_.reverse) def unapply(n: Node): Option[List[Symbol]] = Path.unapply(n).map(_.reverse)
Expand Down
23 changes: 20 additions & 3 deletions slick/src/main/scala/slick/ast/Util.scala
Expand Up @@ -36,7 +36,8 @@ object Scope {
final class NodeOps(val tree: Node) extends AnyVal { final class NodeOps(val tree: Node) extends AnyVal {
import Util._ import Util._


@inline def collect[T](pf: PartialFunction[Node, T]): Seq[T] = NodeOps.collect(tree, pf) @inline def collect[T](pf: PartialFunction[Node, T], stopOnMatch: Boolean = false): Seq[T] =
NodeOps.collect(tree, pf, stopOnMatch)


def collectAll[T](pf: PartialFunction[Node, Seq[T]]): Seq[T] = collect[Seq[T]](pf).flatten def collectAll[T](pf: PartialFunction[Node, Seq[T]]): Seq[T] = collect[Seq[T]](pf).flatten


Expand Down Expand Up @@ -75,6 +76,16 @@ final class NodeOps(val tree: Node) extends AnyVal {
case (s: ElementSymbol, ProductNode(ch)) => ch(s.idx-1) case (s: ElementSymbol, ProductNode(ch)) => ch(s.idx-1)
case (s, n) => Select(n, s) case (s, n) => Select(n, s)
} }

def hasRefTo(s: Symbol): Boolean = findNode {
case Ref(s2) if s2 == s => true
case _ => false
}.isDefined

def hasRefToOneOf(s: Set[Symbol]): Boolean = findNode {
case Ref(s2) if s contains s2 => true
case _ => false
}.isDefined
} }


object NodeOps { object NodeOps {
Expand All @@ -83,9 +94,15 @@ object NodeOps {
// These methods should be in the class but 2.10.0-RC1 took away the ability // These methods should be in the class but 2.10.0-RC1 took away the ability
// to use closures in value classes // to use closures in value classes


def collect[T](tree: Node, pf: PartialFunction[Node, T]): Seq[T] = { def collect[T](tree: Node, pf: PartialFunction[Node, T], stopOnMatch: Boolean): Seq[T] = {
val b = new ArrayBuffer[T] val b = new ArrayBuffer[T]
tree.foreach(pf.andThen[Unit]{ case t => b += t }.orElse[Node, Unit]{ case _ => () }) def f(n: Node): Unit = pf.andThen[Unit] { case t =>
b += t
if(!stopOnMatch) n.nodeChildren.foreach(f)
}.orElse[Node, Unit]{ case _ =>
n.nodeChildren.foreach(f)
}.apply(n)
f(tree)
b b
} }


Expand Down

0 comments on commit 2b14139

Please sign in to comment.