Permalink
Browse files

fixing part 1 of #1408, flatmaps in the tail position of a query are …

…unrolled into nested bind clauses
  • Loading branch information...
Alexander Ioffe Alexander Ioffe
Alexander Ioffe authored and Alexander Ioffe committed Dec 27, 2016
1 parent 47d1db0 commit 3f5c6753b4108564a3583c8b49f0a89374a6dd4c
@@ -17,6 +17,7 @@ testkit {
# All TestkitTest classes to run
testPackage = com.typesafe.slick.testkit.tests
testClasses = [
${testPackage}.ThreeWayJoinTest
${testPackage}.ActionTest
${testPackage}.AggregateTest
${testPackage}.ColumnDefaultTest
@@ -0,0 +1,117 @@
package com.typesafe.slick.testkit.tests
import org.junit.Assert._
import com.typesafe.slick.testkit.util.{RelationalTestDB, AsyncTest}
class ThreeWayJoinTest extends AsyncTest[RelationalTestDB] {
import tdb.profile.api._
// ******************** Full many to many join test (i.e. four table) **********************
def testManyToManyJoin = {
class A(tag: Tag) extends Table[Int](tag, "a_manytomanyjoin") {
def id = column[Int]("id", O.PrimaryKey)
def * = id
def bs = cs.filter(_.aId === id).flatMap(_.b)
}
lazy val as = TableQuery[A]
class B(tag: Tag) extends Table[(Int, Int)](tag, "b_manytomanyjoin") {
def id = column[Int]("id", O.PrimaryKey)
def dId = column[Int]("dId")
def * = (id, dId)
def as = cs.filter(_.bId === id).flatMap(_.a)
def d = foreignKey("d_fk", dId, ds)(_.id)
}
lazy val bs = TableQuery[B]
class C(tag: Tag) extends Table[(Int, Int)](tag, "c_manytomanyjoin") {
def aId = column[Int]("aId")
def bId = column[Int]("bId")
def * = (aId, bId)
def a = foreignKey("a_fk", aId, as)(_.id)
def b = foreignKey("b_fk", bId, bs)(_.id)
}
lazy val cs = TableQuery[C]
class D(tag: Tag) extends Table[Int](tag, "d_manytomanyjoin") {
def id = column[Int]("id", O.PrimaryKey)
def * = id
}
lazy val ds = TableQuery[D]
def q1 = for {
a <- as
b <- a.bs
d <- b.d
} yield (a, b.id, d)
DBIO.seq(
(as.schema ++ bs.schema ++ cs.schema ++ ds.schema).create,
as ++= Seq(1),
ds ++= Seq(3),
bs ++= Seq((2,3)),
cs ++= Seq((1,2)),
q1.result.named("q1").map(_.toSet shouldBe Set((1, 2, 3)))
)
}
// ******************** Many to many join across two tables **********************
def testManyToManyJoinTwice = {
class A(tag: Tag) extends Table[Int](tag, "a_manytomanyjoin2") {
def id = column[Int]("id", O.PrimaryKey)
def * = id
def bs = atbs.filter(_.aId === id).flatMap(_.b)
}
lazy val as = TableQuery[A]
class ATB(tag: Tag) extends Table[(Int, Int)](tag, "atb_manytomanyjoin2") {
def aId = column[Int]("aId")
def bId = column[Int]("bId")
def * = (aId, bId)
def a = foreignKey("a_fk2", aId, as)(_.id)
def b = foreignKey("b_fk2", bId, bs)(_.id)
}
lazy val atbs = TableQuery[ATB]
class B(tag: Tag) extends Table[Int](tag, "b_manytomanyjoin2") {
def id = column[Int]("id", O.PrimaryKey)
def * = id
def cs = btcs.filter(_.bId === id).flatMap(_.c)
}
lazy val bs = TableQuery[B]
class BTC(tag: Tag) extends Table[(Int, Int)](tag, "btc_manytomanyjoin2") {
def bId = column[Int]("bId")
def cId = column[Int]("cId")
def * = (bId, cId)
def b = foreignKey("b_fk3", bId, bs)(_.id)
def c = foreignKey("c_fk3", cId, cs)(_.id)
}
lazy val btcs = TableQuery[BTC]
class C(tag: Tag) extends Table[Int](tag, "c_manytomanyjoin2") {
def id = column[Int]("id", O.PrimaryKey)
def * = id
}
lazy val cs = TableQuery[C]
def q1 = for {
a <- as
b <- a.bs
c <- b.cs
} yield (a, b.id, c)
DBIO.seq(
(as.schema ++ atbs.schema ++ bs.schema ++ btcs.schema ++ cs.schema).create,
as ++= Seq(1),
bs ++= Seq(2),
cs ++= Seq(3),
atbs ++= Seq((1, 2)),
btcs ++= Seq((2, 3)),
q1.result.named("q1").map(_.toSet shouldBe Set((1, 2, 3)))
)
}
}
@@ -106,6 +106,7 @@ object QueryCompiler {
val standardPhases = Vector(
/* Clean up trees from the lifted embedding */
Phase.assignUniqueSymbols,
Phase.unrollTailBinds,
/* Distribute and normalize */
Phase.inferTypes,
Phase.expandTables,
@@ -171,6 +172,7 @@ trait Phase extends (CompilerState => CompilerState) with Logging {
* the standard phases of the query compiler */
object Phase {
/* The standard phases of the query compiler */
val unrollTailBinds = new UnrollTailBinds
val assignUniqueSymbols = new AssignUniqueSymbols
val inferTypes = new InferTypes
val expandTables = new ExpandTables
@@ -0,0 +1,48 @@
package slick.compiler
import slick.ast.Library.AggregateFunctionSymbol
import scala.collection.mutable.{HashSet, HashMap}
import slick.SlickException
import slick.ast._
import TypeUtil._
import Util._
class UnrollTailBinds extends Phase {
val name = "unrollTailBinds"
def apply(state: CompilerState) = state.map(tr(_))
def tr(n: Node): Node = {
n match {
case bb@Bind(br,
Bind(bo, Filter(fa, ff1, where1), Filter(fb, ff2, where2)),
Bind(bi, bf, select)) => {
// make a new symbol
val bm = new AnonSymbol
def rep(node: Node) = {
def repInternal(node: Node): Node = {
val out = node.replace({
case p@Path(brs :: tail) if brs == br => {
Path(List(bm) ++ tail)
}
}, bottomUp = true, keepType = true)
out.mapChildren(repInternal)
}
repInternal(node)
}
// bind all needed elements to the new symbol
val bindCombo = Bind(bo,
Filter(fa, rep(ff1), rep(where1)), //replace s4 here with s5
Bind(bm,
Filter(fb, rep(ff2), rep(where2)), // replace s4 here with s5
Bind(bi, rep(bf), rep(select)))) // replace s4 here with s5
bindCombo.mapChildren(tr)
}
case n => n.mapChildren(tr)
}
}
}

0 comments on commit 3f5c675

Please sign in to comment.