Permalink
Browse files

Improve Take, Drop and zip join translation

- Add a new optional phase `removeTakeDrop` which can translate all
  `Take` and `Drop` nodes to `zipWithIndex` operations.

- Change `resolveZipJoins` to translate zip joins of two queries into
  inner joins on top of `zipWithIndex` operations. Only `zipWithIndex`
  remains as a primitive operation.

- Add a `Subquery` node which prevents merging into an existing
  Comprehension in `mergeToComprehensions`, thus forcing a subquery to
  be created (unless it occurs at the top level). It is used by
  `resolveZipJoins` to isolate the `RowNum` generators.

- Upgrade to H2 1.4.187 because of an operation reordering bug with
  advanced usages of ROWNUM in 1.3.170.

- Fix a typing bug in `rewriteJoins`

- Remove the `OracleStyleRowNum` mix-in for QueryBuilder. It is no
  longer needed with the early translation of zipWithIndex operations.

- Implement a custom `resolveZipJoins` for MySQL to replace the
  translations for supporting RowNumber in the code generator.

Subquery conditions are not yet supported in this version. They are
always treated as `Subquery.Always` which is the conservative choice but
will lead to unnecessary subqueries in some cases.
  • Loading branch information...
szeiger committed Jun 24, 2015
1 parent f170783 commit 36d5552a9992a1e1aeea73235c641eda87d8f44f
@@ -26,6 +26,7 @@
<logger name="slick.compiler.FlattenProjections" level="${log.qcomp.flattenProjections:-inherited}" />
<logger name="slick.compiler.CreateAggregates" level="${log.qcomp.createAggregates:-inherited}" />
<logger name="slick.compiler.RewriteJoins" level="${log.qcomp.rewriteJoins:-inherited}" />
<logger name="slick.compiler.RemoveTakeDrop" level="${log.qcomp.removeTakeDrop:-inherited}" />
<logger name="slick.compiler.ResolveZipJoins" level="${log.qcomp.resolveZipJoins:-inherited}" />
<logger name="slick.compiler.MergeToComprehensions" level="${log.qcomp.mergeToComprehensions:-inherited}" />
<logger name="slick.compiler.FixRowNumberOrdering" level="${log.qcomp.fixRowNumberOrdering:-inherited}" />
View
@@ -36,7 +36,7 @@ object SlickBuild extends Build {
"com.zaxxer" % "HikariCP-java6" % "2.0.1"
)
val mainDependencies = Seq(slf4j, typesafeConfig, reactiveStreams) ++ pools.map(_ % "optional")
val h2 = "com.h2database" % "h2" % "1.3.170"
val h2 = "com.h2database" % "h2" % "1.4.187"
val testDBs = Seq(
h2,
"org.xerial" % "sqlite-jdbc" % "3.8.7",
@@ -26,6 +26,7 @@
<logger name="slick.compiler.FlattenProjections" level="${log.qcomp.flattenProjections:-inherited}" />
<logger name="slick.compiler.CreateAggregates" level="${log.qcomp.createAggregates:-inherited}" />
<logger name="slick.compiler.RewriteJoins" level="${log.qcomp.rewriteJoins:-inherited}" />
<logger name="slick.compiler.RemoveTakeDrop" level="${log.qcomp.removeTakeDrop:-inherited}" />
<logger name="slick.compiler.ResolveZipJoins" level="${log.qcomp.resolveZipJoins:-inherited}" />
<logger name="slick.compiler.MergeToComprehensions" level="${log.qcomp.mergeToComprehensions:-inherited}" />
<logger name="slick.compiler.FixRowNumberOrdering" level="${log.qcomp.fixRowNumberOrdering:-inherited}" />
@@ -180,44 +180,40 @@ class NewQuerySemanticsTest extends AsyncTest[RelationalTestDB] {
} yield (c._1, c._3)
// Map to tuple, then filter, with self-join
val q4b_0 = coffees.map(c => (c.name, c.price, 42)).filter(_._2 < 800)
val q4b = for {
def q4b_0 = coffees.map(c => (c.name, c.price, 42)).filter(_._2 < 800)
def q4b = for {
c <- q4b_0
d <- q4b_0
} yield (c,d)
def a3 = seq(
q2.result.named("More elaborate query").map(_.toSet).map { r2 =>
def a3 = for {
_ <- q2.result.named("More elaborate query").map(_.toSet).map { r2 =>
r2 shouldBe Set(
("Colombian","Acme, Inc."),
("French_Roast","Superior Coffee"),
("Colombian_Decaf","Acme, Inc.")
)
},
q3.result.named("Lifting scalar values").map(_.toSet).map { r3 =>
}
_ <- q3.result.named("Lifting scalar values").map(_.toSet).map { r3 =>
r3 shouldBe Set(("Colombian_Decaf","Acme, Inc.","Colombian_Decaf",0,3396))
},
q3b.result.named("Lifting scalar values, with extra tuple").map(_.toSet).map { r3b =>
}
_ <- q3b.result.named("Lifting scalar values, with extra tuple").map(_.toSet).map { r3b =>
r3b shouldBe Set(
("Colombian","Acme, Inc.","Colombian",0,799,42),
("French_Roast","Superior Coffee","French_Roast",0,1598,42),
("Colombian_Decaf","Acme, Inc.","Colombian_Decaf",0,3396,42)
)
},
ifCap(rcap.pagingNested) {
q4.result.named("Map to tuple, then filter").map { r4 =>
r4.toSet shouldBe Set(("Colombian",42))
}
},
q4b.result.map(_.toSet).map { r4b =>
r4b shouldBe Set(
(("Colombian",799,42),("Colombian",799,42)),
(("Colombian",799,42),("French_Roast",799,42)),
(("French_Roast",799,42),("Colombian",799,42)),
(("French_Roast",799,42),("French_Roast",799,42))
)
}
)
_ <- ifCap(rcap.pagingNested) {
mark("q4", q4.result).named("q4: Map to tuple, then filter").map(_.toSet shouldBe Set(("Colombian",42)))
}
_ <- mark("q4b", q4b.result).map(_.toSet shouldBe Set(
(("Colombian",799,42),("Colombian",799,42)),
(("Colombian",799,42),("French_Roast",799,42)),
(("French_Roast",799,42),("Colombian",799,42)),
(("French_Roast",799,42),("French_Roast",799,42))
))
} yield ()
// Implicit self-join
val q5_0 = coffees.sortBy(_.price).take(2)
@@ -16,6 +16,7 @@ class PagingTest extends AsyncTest[RelationalTestDB] {
val q2 = q1 take 5
def q3 = q1 drop 5
def q4 = q1 drop 5 take 3
val q4b = q1.drop(5).take(3).sortBy(_.id)
def q5 = q1 take 5 drop 3
val q6 = q1 take 0
@@ -27,6 +28,7 @@ class PagingTest extends AsyncTest[RelationalTestDB] {
_ <- ifCap(rcap.pagingDrop)(for {
_ <- mark("q3", q3.result).map(_ shouldBe (6 to 10).toList)
_ <- mark("q4", q4.result).map(_ shouldBe (6 to 8).toList)
_ <- mark("q4b", q4b.result).map(_ shouldBe (6 to 8).toList)
_ <- mark("q5", q5.result).map(_ shouldBe (4 to 5).toList)
} yield ())
_ <- mark("q6", q6.result).map(_ shouldBe Nil)
@@ -88,11 +88,11 @@ class TypedStaticQueryTest {
},
s3.map { o3 =>
val t3: Vector[Foo] = o3.map(Foo(_))
assertEquals(Vector(Foo(101), Foo(150), Foo(49)), t3)
assertEquals(Set(Foo(101), Foo(150), Foo(49)), t3.toSet)
},
s4.map { o4 =>
val t4: Vector[Bar] = o4.map(Bar(_))
assertEquals(List(Bar("Groundsville"), Bar("Meadows"), Bar("Mendocino")), t4)
assertEquals(Set(Bar("Groundsville"), Bar("Meadows"), Bar("Mendocino")), t4.toSet)
}
)), Duration.Inf)
} finally dc.db.close()
@@ -209,6 +209,23 @@ final case class CollectionCast(child: Node, cons: CollectionTypeConstructor) ex
def nodeMapServerSide(keepType: Boolean, r: Node => Node) = mapChildren(r, keepType)
}
/** A marker node wrapping a collection-valued `child`. When it sits between two
  * collection-valued operations that `mergeToComprehensions` would otherwise fuse into a
  * single Comprehension, it requests that a subquery be created instead, provided the
  * attached condition holds. The node has the same type as its child.
  *
  * NOTE(review): the `mergeToComprehensions` case visible alongside this definition matches
  * `Subquery(n, _)` and ignores the condition -- confirm before relying on `AboveGroupBy`. */
final case class Subquery(child: Node, condition: Subquery.Condition) extends UnaryNode with SimplyTypedNode {
  type Self = Subquery

  /** Rebuild this node around a replacement child, keeping the condition. */
  protected[this] def rebuild(child: Node): Self = copy(child = child)

  /** A Subquery is transparent with respect to typing: it has the child's type. */
  protected def buildType: Type = child.nodeType
}

object Subquery {
  /** Describes when a [[Subquery]] node should actually produce a subquery. */
  sealed trait Condition

  /** Unconditionally create a subquery. */
  case object Always extends Condition

  /** Create a subquery only if the current Comprehension contains a GROUP BY, ORDER BY or
    * HAVING clause. */
  case object AboveGroupBy extends Condition
}
/** Common superclass for expressions of type (CollectionType(c, t), _) => CollectionType(c, t). */
abstract class FilteredQuery extends Node {
protected[this] def generator: TermSymbol
@@ -27,6 +27,7 @@ object Util {
/** Extra methods for Nodes. */
final class NodeOps(val tree: Node) extends AnyVal {
import Util._
import NodeOps._
@inline def collect[T](pf: PartialFunction[Node, T], stopOnMatch: Boolean = false): Seq[T] = {
val b = new ArrayBuffer[T]
@@ -68,16 +69,19 @@ final class NodeOps(val tree: Node) extends AnyVal {
* to the invalidated TypeSymbols have their types unassigned, so that the whole tree can be
* retyped afterwards to get the correct new TypeSymbols in. */
def replaceInvalidate(f: PartialFunction[(Node, Set[TypeSymbol], Node), (Node, Set[TypeSymbol])]): Node = {
  // The accumulator folded through the tree is the set of TypeSymbols invalidated so far.
  // `f` may replace a node and extend that set; for every node `f` does not handle, the
  // fallback below untypes any Ref or Select whose type still mentions an invalidated symbol.
  // The membership test is the shared `NodeOps.containsTS` (in scope via `import NodeOps._`),
  // the same helper `untypeReferences` uses, instead of a private local copy of it.
  replaceFold(Set.empty[TypeSymbol])(f.orElse {
    case ((n: Ref), invalid, _) if containsTS(n.nodeType, invalid) => (n.untyped, invalid)
    case ((n: Select), invalid, _) if containsTS(n.nodeType, invalid) => (n.untyped, invalid)
  })._1 // replaceFold yields a pair; only the rewritten tree is returned
}
/** Remove the assigned type from every `Ref` and `Select` whose type mentions one of the
  * given TypeSymbols, working bottom-up. With an empty set the tree is returned unchanged. */
def untypeReferences(invalid: Set[TypeSymbol]): Node =
  if(invalid.nonEmpty)
    replace({
      case ref: Ref if containsTS(ref.nodeType, invalid) => ref.untyped
      case sel: Select if containsTS(sel.nodeType, invalid) => sel.untyped
    }, bottomUp = true)
  else tree
def foreach[U](f: (Node => U)): Unit = {
def g(n: Node) {
f(n)
@@ -101,3 +105,12 @@ final class NodeOps(val tree: Node) extends AnyVal {
case (s, n) => Select(n, s)
}
}
private object NodeOps {
  /** Check whether type `t` refers, at any nesting depth, to one of the `invalid` TypeSymbols.
    * A `NominalType` matches on its own symbol or on its expanded type; any other type
    * matches if one of its children does. An empty set never matches. */
  private def containsTS(t: Type, invalid: Set[TypeSymbol]): Boolean =
    invalid.nonEmpty && (t match {
      case NominalType(sym, expanded) => invalid(sym) || containsTS(expanded, invalid)
      case other => other.children.exists(containsTS(_, invalid))
    })
}
@@ -9,7 +9,7 @@ class FixRowNumberOrdering extends Phase {
val name = "fixRowNumberOrdering"
def apply(state: CompilerState) =
if(state.get(Phase.resolveZipJoins).get) state.map(n => fix(n)) else state
if(state.get(Phase.resolveZipJoins).getOrElse(false)) state.map(n => fix(n)) else state
/** Push ORDER BY into RowNumbers in ordered Comprehensions. */
def fix(n: Node, parent: Option[Comprehension] = None): Node = (n, parent) match {
@@ -232,6 +232,9 @@ class MergeToComprehensions extends Phase {
logger.debug("Converted Union:", u2)
(u2, rep1)
case Subquery(n, _) =>
createTopLevel(n)
case n =>
val (c, rep) = mergeTakeDrop(n, false)
val mappings = rep.mapValues(_ :: Nil).toSeq
@@ -104,6 +104,7 @@ object QueryCompiler {
Phase.removeMappedTypes,
/* Convert to column form */
Phase.expandSums,
// optional removeTakeDrop goes here
// optional emulateOuterJoins goes here
Phase.expandConditionals,
Phase.expandRecords,
@@ -172,6 +173,7 @@ object Phase {
val createAggregates = new CreateAggregates
val rewriteJoins = new RewriteJoins
val verifySymbols = new VerifySymbols
val removeTakeDrop = new RemoveTakeDrop
val resolveZipJoins = new ResolveZipJoins
val relabelUnions = new RelabelUnions
val mergeToComprehensions = new MergeToComprehensions
@@ -181,6 +183,7 @@ object Phase {
val removeFieldNames = new RemoveFieldNames
/* Extra phases that are not enabled by default */
val resolveZipJoinsRownumStyle = new ResolveZipJoins(rownumStyle = true)
val rewriteBooleans = new RewriteBooleans
val specializeParameters = new SpecializeParameters
val verifyTypes = new VerifyTypes
@@ -0,0 +1,67 @@
package slick.compiler
import slick.ast._
import Util._
import TypeUtil._
import QueryParameter.constOp
import scala.collection.mutable
/** Compiler phase that eliminates `Take` and `Drop` nodes. A (possibly nested) chain of
  * take/drop operations is collapsed into a single "drop d, then take t" and rewritten as
  * a zip join of the source with a row number starting at 1 (`RangeFrom(1L)`), a filter on
  * that row number, and a projection back to the original element. */
class RemoveTakeDrop extends Phase {
  val name = "removeTakeDrop"

  def apply(state: CompilerState) = state.map { tree =>
    val rewritten = tree.replaceInvalidate {
      case (n @ TakeDrop(from, t, d), invalid, _) =>
        logger.debug(s"""Translating "drop $d, then take $t" to zipWithIndex operation:""", n)
        val fromRetyped = from.infer()
        // Use the source directly if it already is a Bind, otherwise wrap it in an identity Bind
        val source = fromRetyped match {
          case bind: Bind => bind
          case other =>
            val s = new AnonSymbol
            Bind(s, other, Pure(Ref(s)))
        }
        // Pair every row with its 1-based position: element 1 is the row, element 2 the index
        val zipped = Join(new AnonSymbol, new AnonSymbol, source, RangeFrom(1L), JoinType.Zip, LiteralNode(true))
        val zippedSym = new AnonSymbol
        val zippedBind = Bind(zippedSym, zipped, Pure(Ref(zippedSym)))
        val fs = new AnonSymbol
        // A fresh reference to the row number column for each use
        def rowNum: Node = Select(Ref(fs), ElementSymbol(2))
        val keep = (t, d) match {
          case (None, Some(dropN)) =>
            Library.>.typed[Boolean](rowNum, dropN)
          case (Some(takeN), None) =>
            Library.<=.typed[Boolean](rowNum, takeN)
          case (Some(takeN), Some(dropN)) =>
            // Keep rows with dropN < rowNum <= takeN + dropN
            Library.And.typed[Boolean](
              Library.>.typed[Boolean](rowNum, dropN),
              Library.<=.typed[Boolean](rowNum, plus(takeN, dropN))
            )
        }
        val filtered = Filter(fs, zippedBind, keep)
        // Project the row number away again
        val resultSym = new AnonSymbol
        val result = Bind(resultSym, filtered, Pure(Select(Ref(resultSym), ElementSymbol(1))))
        logger.debug(s"""Translated "drop $d, then take $t" to zipWithIndex operation:""", result)
        val invalidate = fromRetyped.nodeType.collect { case NominalType(ts, _) => ts }
        logger.debug("Invalidating TypeSymbols: "+invalidate.mkString(", "))
        (result, invalid ++ invalidate)
    }
    logger.debug("After removeTakeDrop without inferring:", rewritten)
    rewritten.infer()
  }

  /** `l + r` on `Long` count expressions, built via `constOp`. */
  private def plus(l: Node, r: Node): Node = constOp[Long]("+")(_ + _)(l, r)

  /** `min(l, r)` on `Long` count expressions, built via `constOp`. */
  private def smaller(l: Node, r: Node): Node = constOp[Long]("min")(math.min)(l, r)

  /** `max(0, taken - dropped)`: what is left of a take window after dropping from it. */
  private def remaining(taken: Node, dropped: Node): Node =
    constOp[Long]("max")(math.max)(LiteralNode(0L).infer(), constOp[Long]("-")(_ - _)(taken, dropped))

  /** An extractor for nested Take and Drop nodes. It yields the innermost source together
    * with an optional take count and an optional drop count, to be read as
    * "drop first, then take". */
  object TakeDrop {
    def unapply(n: Node): Option[(Node, Option[Node], Option[Node])] = n match {
      case Take(from, num) =>
        unapply(from) match {
          // Taking from an existing window can only shrink it; the drop count is unaffected
          case Some((f, t, d)) => Some((f, Some(t.fold(num)(smaller(_, num))), d))
          case _ => Some((from, Some(num), None))
        }
      case Drop(from, num) =>
        unapply(from) match {
          // Dropping shortens an existing take window and adds to an existing drop count
          case Some((f, t, d)) if t.isDefined || d.isDefined =>
            Some((f, t.map(remaining(_, num)), Some(d.fold(num)(plus(_, num)))))
          case _ => Some((from, None, Some(num)))
        }
      case _ => None
    }
  }
}
Oops, something went wrong.

0 comments on commit 36d5552

Please sign in to comment.