Skip to content

Commit

Permalink
Early hoisting of client-side operations:
Browse files Browse the repository at this point in the history
- Run `createResultSetMapping` and `hoistClientOps` before comprehension
  fusion. Hoisting only moves individual operations but does not
  eliminate entire Bind/Comprehension nodes. By running it early, we
  can hoist as much as possible and then naturally merge the remaining
  mappings in `mergeToComprehensions`.

- Rewrite `hoistClientOps` to work on unmerged trees, bubbling up `Bind`
  operations that may be eligible for hoisting to the top level.

- Simplify well-typedness checking by removing the option for verifying
  only the server side. With the new design the entire AST stays
  well-typed after `expandTables`.

- Merge `Filter` operations in the correct order (bottom to top) in
  `mergeToComprehensions`.

- Push `Filter` down into `Union` in `reorderOperations`.
  • Loading branch information
szeiger committed Aug 3, 2015
1 parent ffee22c commit aaf2a22
Show file tree
Hide file tree
Showing 13 changed files with 140 additions and 139 deletions.
2 changes: 1 addition & 1 deletion common-test-resources/logback.xml
Expand Up @@ -28,10 +28,10 @@
<logger name="slick.compiler.RewriteJoins" level="${log.qcomp.rewriteJoins:-inherited}" />
<logger name="slick.compiler.RemoveTakeDrop" level="${log.qcomp.removeTakeDrop:-inherited}" />
<logger name="slick.compiler.ResolveZipJoins" level="${log.qcomp.resolveZipJoins:-inherited}" />
<logger name="slick.compiler.HoistClientOps" level="${log.qcomp.hoistClientOps:-inherited}" />
<logger name="slick.compiler.ReorderOperations" level="${log.qcomp.reorderOperations:-inherited}" />
<logger name="slick.compiler.MergeToComprehensions" level="${log.qcomp.mergeToComprehensions:-inherited}" />
<logger name="slick.compiler.FixRowNumberOrdering" level="${log.qcomp.fixRowNumberOrdering:-inherited}" />
<logger name="slick.compiler.HoistClientOps" level="${log.qcomp.hoistClientOps:-inherited}" />
<logger name="slick.compiler.PruneProjections" level="${log.qcomp.pruneProjections:-inherited}" />
<logger name="slick.compiler.RewriteBooleans" level="${log.qcomp.rewriteBooleans:-inherited}" />
<logger name="slick.compiler.SpecializeParameters" level="${log.qcomp.specializeParameters:-inherited}" />
Expand Down
2 changes: 1 addition & 1 deletion slick-testkit/src/doctest/resources/logback.xml
Expand Up @@ -28,10 +28,10 @@
<logger name="slick.compiler.RewriteJoins" level="${log.qcomp.rewriteJoins:-inherited}" />
<logger name="slick.compiler.RemoveTakeDrop" level="${log.qcomp.removeTakeDrop:-inherited}" />
<logger name="slick.compiler.ResolveZipJoins" level="${log.qcomp.resolveZipJoins:-inherited}" />
<logger name="slick.compiler.HoistClientOps" level="${log.qcomp.hoistClientOps:-inherited}" />
<logger name="slick.compiler.ReorderOperations" level="${log.qcomp.reorderOperations:-inherited}" />
<logger name="slick.compiler.MergeToComprehensions" level="${log.qcomp.mergeToComprehensions:-inherited}" />
<logger name="slick.compiler.FixRowNumberOrdering" level="${log.qcomp.fixRowNumberOrdering:-inherited}" />
<logger name="slick.compiler.HoistClientOps" level="${log.qcomp.hoistClientOps:-inherited}" />
<logger name="slick.compiler.PruneProjections" level="${log.qcomp.pruneProjections:-inherited}" />
<logger name="slick.compiler.RewriteBooleans" level="${log.qcomp.rewriteBooleans:-inherited}" />
<logger name="slick.compiler.SpecializeParameters" level="${log.qcomp.specializeParameters:-inherited}" />
Expand Down
Expand Up @@ -501,6 +501,7 @@ class NewQuerySemanticsTest extends AsyncTest[RelationalTestDB] {
val q13 = (as.filter(_.id < 2) union as.filter(_.id > 2)).map(_.id)
val q14 = q13.to[Set]
val q15 = (as.map(a => a.id.?).filter(_ < 2) unionAll as.map(a => a.id.?).filter(_ > 2)).map(_.get).to[Set]
val q16 = (as.map(a => a.id.?).filter(_ < 2) unionAll as.map(a => a.id.?).filter(_ > 2)).map(_.getOrElse(-1)).to[Set].filter(_ =!= 42)

if(tdb.driver == H2Driver) {
assertNesting(q1, 1)
Expand All @@ -525,7 +526,8 @@ class NewQuerySemanticsTest extends AsyncTest[RelationalTestDB] {
assertNesting(q12, 2)
assertNesting(q13, 2)
assertNesting(q14, 2)
//assertNesting(q15, 2) //TODO
assertNesting(q15, 2)
assertNesting(q16, 2)
}

for {
Expand Down Expand Up @@ -555,6 +557,7 @@ class NewQuerySemanticsTest extends AsyncTest[RelationalTestDB] {
_ <- mark("q13", q13.result).map(_.toSet shouldBe Set(1, 3))
_ <- mark("q14", q14.result).map(_ shouldBe Set(1, 3))
_ <- mark("q15", q15.result).map(_ shouldBe Set(1, 3))
_ <- mark("q16", q16.result).map(_ shouldBe Set(1, 3))
} yield ()
}

Expand Down
2 changes: 1 addition & 1 deletion slick/src/main/scala/slick/ast/Node.scala
Expand Up @@ -202,7 +202,7 @@ final case class Pure(value: Node, identity: TypeSymbol = new AnonTypeSymbol) ex
protected def buildType = CollectionType(TypedCollectionTypeConstructor.seq, NominalType(identity, value.nodeType))
}

final case class CollectionCast(child: Node, cons: CollectionTypeConstructor) extends UnaryNode with SimplyTypedNode with ClientSideOp {
final case class CollectionCast(child: Node, cons: CollectionTypeConstructor) extends UnaryNode with SimplyTypedNode {
type Self = CollectionCast
protected[this] def rebuild(child: Node) = copy(child = child)
protected def buildType = CollectionType(cons, child.nodeType.asCollectionType.elementType)
Expand Down
8 changes: 1 addition & 7 deletions slick/src/main/scala/slick/compiler/CodeGen.scala
Expand Up @@ -17,7 +17,7 @@ abstract class CodeGen extends Phase {
var nmap: Option[Node] = None
var compileMap: Option[Node] = Some(rsm.map)

val nfrom = mapServerSideOrCast(rsm.from, keepType = true) { ss =>
val nfrom = ClientSideOp.mapServerSide(rsm.from, keepType = true) { ss =>
logger.debug("Compiling server-side and mapping with server-side:", ss)
val (nss, nmapOpt) = compileServerSideAndMapping(ss, compileMap, state)
nmapOpt match {
Expand All @@ -32,12 +32,6 @@ abstract class CodeGen extends Phase {
rsm.copy(from = nfrom, map = nmap.get) :@ rsm.nodeType
}

private[this] def mapServerSideOrCast(n: Node, keepType: Boolean = true)(f: Node => Node): Node = n match {
case n: CollectionCast => f(n)
case n: ClientSideOp => n.nodeMapServerSide(keepType, (ch => mapServerSideOrCast(ch, keepType)(f)))
case n => f(n)
}

def compileServerSideAndMapping(serverSide: Node, mapping: Option[Node], state: CompilerState): (Node, Option[Node])

/** Extract the source tree and type, after possible CollectionCast operations, from a tree */
Expand Down
24 changes: 11 additions & 13 deletions slick/src/main/scala/slick/compiler/CreateResultSetMapping.scala
Expand Up @@ -7,8 +7,7 @@ import TypeUtil._

/** Create a ResultSetMapping root node, ensure that the top-level server-side node returns a
* collection, and hoist client-side type conversions into the ResultSetMapping. The original
* result type (which was removed by `removeMappedTypes`) is assigned back to the top level,
* so the client side is no longer well-typed after this phase. */
* result type (which was removed by `removeMappedTypes`) is assigned back to the top level. */
class CreateResultSetMapping extends Phase {
val name = "createResultSetMapping"

Expand All @@ -23,12 +22,12 @@ class CreateResultSetMapping extends Phase {
val gen = new AnonSymbol
(tpe match {
case CollectionType(cons, el) =>
ResultSetMapping(gen, collectionCast(ch, cons), createResult(gen, el, syms))
ResultSetMapping(gen, collectionCast(ch, cons).infer(), createResult(Ref(gen) :@ ch.nodeType.asCollectionType.elementType, el, syms))
case t =>
ResultSetMapping(gen, ch, createResult(gen, t, syms))
ResultSetMapping(gen, ch, createResult(Ref(gen) :@ ch.nodeType.asCollectionType.elementType, t, syms))
})
}.infer()
}.withWellTyped(WellTyped.ServerSide)
}

def collectionCast(ch: Node, cons: CollectionTypeConstructor): Node = ch.nodeType match {
case CollectionType(c, _) if c == cons => ch
Expand All @@ -37,7 +36,7 @@ class CreateResultSetMapping extends Phase {

/** Create a structured return value for the client side, based on the
* result type (which may contain MappedTypes). */
def createResult(sym: TermSymbol, tpe: Type, syms: IndexedSeq[TermSymbol]): Node = {
def createResult(ref: Ref, tpe: Type, syms: IndexedSeq[TermSymbol]): Node = {
var curIdx = 0
def f(tpe: Type): Node = {
logger.debug("Creating mapping from "+tpe)
Expand All @@ -56,24 +55,23 @@ class CreateResultSetMapping extends Phase {
curIdx += 1
// Assign the original type. Inside a RebuildOption the actual column type will always be
// Option-lifted but we can still treat it as the base type when the discriminator matches.
Library.SilentCast.typed(t.structuralRec, Select(Ref(sym), syms(curIdx-1)))
val sel = Select(ref, syms(curIdx-1)).infer()
val tSel = t.structuralRec
if(sel.nodeType.structuralRec == tSel) sel else Library.SilentCast.typed(tSel, sel)
}
}
f(tpe)
}
}

/** Remove all mapped types from the tree and store the original top-level type as the phase state
* to be used later for building the ResultSetMapping. After this phase the entire AST should be
* well-typed until `createResultSetMapping`. */
* to be used later for building the ResultSetMapping. */
class RemoveMappedTypes extends Phase {
val name = "removeMappedTypes"
type State = Type

def apply(state: CompilerState) = {
val tpe = state.tree.nodeType
state.withNode(removeTypeMapping(state.tree)).withWellTyped(WellTyped.All) + (this -> tpe)
}
def apply(state: CompilerState) =
state.withNode(removeTypeMapping(state.tree)) + (this -> state.tree.nodeType)

/** Remove TypeMapping nodes and MappedTypes */
def removeTypeMapping(n: Node): Node = n match {
Expand Down
6 changes: 3 additions & 3 deletions slick/src/main/scala/slick/compiler/ExpandTables.scala
Expand Up @@ -5,8 +5,8 @@ import Util._
import TypeUtil._

/** Expand table-valued expressions in the result type to their star projection and compute the
* missing structural expansions of table types. After this phase the server side of the AST
* should always be well-typed. */
* missing structural expansions of table types. After this phase the AST should always be
* well-typed. */
class ExpandTables extends Phase {
val name = "expandTables"

Expand Down Expand Up @@ -40,7 +40,7 @@ class ExpandTables extends Phase {
.infer(Type.Scope(sym -> tree2.nodeType.asCollectionType.elementType), typeChildren = true)
Bind(sym, tree2, Pure(mapping)).infer()
}
}}.withWellTyped(WellTyped.ServerSide)
}}.withWellTyped(true)

/** Create an expression that copies a structured value, expanding tables in it. */
def createResult(expansions: Map[TableIdentitySymbol, (TermSymbol, Node)], path: Node, tpe: Type): Node = tpe match {
Expand Down
142 changes: 76 additions & 66 deletions slick/src/main/scala/slick/compiler/HoistClientOps.scala
@@ -1,74 +1,94 @@
package slick.compiler

import scala.util.control.NonFatal
import slick.{SlickTreeException, SlickException}
import slick.SlickException
import slick.ast._
import Util._
import TypeUtil._
import slick.ast.Util._
import slick.ast.TypeUtil._
import slick.util.{Ellipsis, ??}

import scala.util.control.NonFatal

/** Lift operations that are preferred to be performed on the client side
* out of sub-queries. */
/** Lift applicable operations at the top level to the client side. */
class HoistClientOps extends Phase {
val name = "hoistClientOps"

def apply(state: CompilerState) = state.map { tree =>
ClientSideOp.mapResultSetMapping(tree, false) { case rsm @ ResultSetMapping(_, ss, _) =>
val CollectionType(cons, NominalType(_, StructType(defs1))) = ss.nodeType
val base = new AnonSymbol
val proj = StructNode(defs1.map { case (s, _) => (s, Select(Ref(base), s)) })
val ResultSetMapping(_, rsmFrom, rsmProj) = hoist(ResultSetMapping(base, ss, proj))
logger.debug("Hoisted projection:", rsmProj)
logger.debug("Rewriting remaining DB side:", rsmFrom)
val rsm2 = ResultSetMapping(base, rewriteDBSide(rsmFrom), rsmProj).infer()
fuseResultSetMappings(rsm.copy(from = rsm2)).infer()
def apply(state: CompilerState) = state.map(ClientSideOp.mapResultSetMapping(_) { rsm =>
val from1 = shuffle(rsm.from)
from1 match {
case Bind(s2, from2, Pure(StructNode(defs2), ts2)) =>
// Extract client-side operations into ResultSetMapping
val hoisted = defs2.map { case (ts, n) => (ts, n, unwrap(n)) }
logger.debug("Hoisting operations from defs: " + hoisted.filter(t => t._2 ne t._3._1).map(_._1).mkString(", "))
val newDefsM = hoisted.map { case (ts, n, (n2, wrap)) => (n2, new AnonSymbol) }.toMap
val oldDefsM = hoisted.map { case (ts, n, (n2, wrap)) => (ts, wrap(Select(Ref(rsm.generator), newDefsM(n2)))) }.toMap
val bind2 = rewriteDBSide(Bind(s2, from2, Pure(StructNode(newDefsM.map(_.swap).toVector), new AnonTypeSymbol)).infer())
val rsm2 = rsm.copy(from = bind2, map = rsm.map.replace {
case Select(Ref(s), f) if s == rsm.generator => oldDefsM(f)
}).infer()
logger.debug("New ResultSetMapping:", Ellipsis(rsm2, List(0, 0)))
rsm2
case _ =>
val from2 = rewriteDBSide(from1)
if(from2 eq rsm.from) rsm else rsm.copy(from = from2).infer()
}
}
})

/** Fuse nested ResultSetMappings. Only the outer one may contain nested
* structures. Inner ResultSetMappings must produce a linearized
* ProductNode. */
def fuseResultSetMappings(rsm: ResultSetMapping): ResultSetMapping = rsm.from match {
case ResultSetMapping(gen2, from2, StructNode(ch2)) =>
logger.debug("Fusing ResultSetMapping:", rsm)
val ch2m = ch2.toMap
val nmap = rsm.map.replace({
case Select(Ref(sym), ElementSymbol(idx)) if sym == rsm.generator => ch2(idx-1)._2
case Select(Ref(sym), f) if sym == rsm.generator => ch2m(f)
case n @ Library.SilentCast(ch :@ tpe2) :@ tpe =>
if(tpe.structural == tpe2.structural) ch else {
logger.debug(s"SilentCast cannot be elided: $tpe != $tpe2")
n
}
}, bottomUp = true, keepType = true)
fuseResultSetMappings(ResultSetMapping(gen2, from2, nmap))
case n => rsm
}
/** Pull Bind nodes up to the top level through Filter and CollectionCast. */
def shuffle(n: Node): Node = n match {
case n @ Bind(s1, from1, sel1) =>
shuffle(from1) match {
case bind2 @ Bind(s2, from2, sel2 @ Pure(StructNode(elems2), ts2)) if !from2.isInstanceOf[GroupBy] =>
logger.debug("Merging top-level Binds", Ellipsis(n.copy(from = bind2), List(0,0)))
val defs = elems2.toMap
bind2.copy(select = sel1.replace {
case Select(Ref(s), f) if s == s1 => defs(f)
}).infer()
case from2 =>
if(from2 eq from1) n else n.copy(from = from2) :@ n.nodeType
}

def hoist(tree: Node): Node = {
logger.debug("Hoisting in:", tree)
val defs = tree.collectAll[(TermSymbol, Option[(Node => Node)])] { case StructNode(ch) =>
ch.map { case (s, n) =>
val u = unwrap(n)
logger.debug("Unwrapped "+n+" to "+u)
if(u._1 eq n) (s, None) else (s, Some(u._2))
// Push CollectionCast down unless it casts from a collection without duplicates to one with duplicates.
//TODO: Identity mappings are reversible, so we can safely allow them for any kind of conversion.
case n @ CollectionCast(from1 :@ CollectionType(cons1, _), cons2) if !cons1.isUnique || cons2.isUnique =>
shuffle(from1) match {
case Bind(s1, bfrom1, sel1 @ Pure(StructNode(elems1), ts1)) if !bfrom1.isInstanceOf[GroupBy] =>
val res = Bind(s1, CollectionCast(bfrom1, cons2), sel1.replace { case Ref(s) if s == s1 => Ref(s) }).infer()
logger.debug("Pulled Bind out of CollectionCast", Ellipsis(res, List(0,0)))
res
case from2 => if(from2 eq from1) n else n.copy(child = from2) :@ n.nodeType
}
}.collect { case (s, Some(u)) => (s, u) }.toMap
logger.debug("Unwrappable defs: "+defs)

if(defs.isEmpty) tree else {
lazy val tr: PartialFunction[Node, Node] = {
case p @ Path(elems @ (h :: _)) if defs.contains(h) =>
defs(h).apply(Path(elems)) // wrap an untyped copy
case d: DefNode => d.mapScopedChildren {
case (Some(sym), n) if defs.contains(sym) =>
unwrap(n)._1.replace(tr)
case (_, n) => n.replace(tr)
}
case n @ Filter(s1, from1, pred1) =>
shuffle(from1) match {
case from2 @ Bind(bs1, bfrom1, sel1 @ Pure(StructNode(elems1), ts1)) if !bfrom1.isInstanceOf[GroupBy] =>
logger.debug("Pulling Bind out of Filter", Ellipsis(n.copy(from = from2), List(0, 0)))
val s3 = new AnonSymbol
val defs = elems1.toMap
val res = Bind(bs1, Filter(s3, bfrom1, pred1.replace {
case Select(Ref(s), f) if s == s1 => defs(f).replace { case Ref(s) if s == bs1 => Ref(s3) }
}), sel1.replace { case Ref(s) if s == bs1 => Ref(s) })
logger.debug("Pulled Bind out of Filter", Ellipsis(res, List(0,0)))
res.infer()
case from2 =>
if(from2 eq from1) n else n.copy(from = from2) :@ n.nodeType
}
tree.replace(tr)
}

case n => n
}

/** Remove a hoistable operation from a top-level column and create a function to
* reapply it at the client side. */
def unwrap(n: Node): (Node, (Node => Node)) = n match {
case GetOrElse(ch, default) =>
val (recCh, recTr) = unwrap(ch)
(recCh, { sym => GetOrElse(recTr(sym), default) })
case OptionApply(ch) =>
val (recCh, recTr) = unwrap(ch)
(recCh, { sym => OptionApply(recTr(sym)) })
case n => (n, identity)
}

/** Rewrite remaining `GetOrElse` operations in the server-side tree into conditionals. */
def rewriteDBSide(tree: Node): Node = tree match {
case GetOrElse(ch, default) =>
val d = try default() catch {
Expand All @@ -80,14 +100,4 @@ class HoistClientOps extends Phase {
Library.IfNull.typed(tpe, ch2, LiteralNode.apply(tpe, d)).infer()
case n => n.mapChildren(rewriteDBSide, keepType = true)
}

def unwrap(n: Node): (Node, (Node => Node)) = n match {
case GetOrElse(ch, default) =>
val (recCh, recTr) = unwrap(ch)
(recCh, { sym => GetOrElse(recTr(sym), default) })
case OptionApply(ch) =>
val (recCh, recTr) = unwrap(ch)
(recCh, { sym => OptionApply(recTr(sym)) })
case n => (n, identity)
}
}
15 changes: 11 additions & 4 deletions slick/src/main/scala/slick/compiler/MergeToComprehensions.scala
Expand Up @@ -25,7 +25,9 @@ class MergeToComprehensions extends Phase {

type Mappings = Seq[((TypeSymbol, TermSymbol), List[TermSymbol])]

def apply(state: CompilerState) = state.map(convert)
def apply(state: CompilerState) = state.map(n => ClientSideOp.mapResultSetMapping(n, keepType = false) { rsm =>
rsm.copy(from = convert(rsm.from), map = rsm.map.replace { case r: Ref => r.untyped })
}.infer())

def convert(tree: Node): Node = {
// Find all references into tables so we can convert TableNodes to Comprehensions
Expand Down Expand Up @@ -233,6 +235,8 @@ class MergeToComprehensions extends Phase {
}

def convert1(n: Node): Node = n match {
case CollectionCast(_, _) =>
n.mapChildren(convert1, keepType = true)
case n :@ Type.Structural(CollectionType(cons, el)) =>
convertOnlyInScalar(createTopLevel(n)._1)
case a: Aggregate =>
Expand All @@ -254,10 +258,13 @@ class MergeToComprehensions extends Phase {
case n => convert1(n)
}

convert1(tree)
val tree2 :@ CollectionType(cons2, _) = convert1(tree)
val cons1 = tree.nodeType.asCollectionType.cons
if(cons2 != cons1) CollectionCast(tree2, cons1).infer()
else tree2
}

/** Merge the common operations Bind, Filter and CollectionBase into an existing Comprehension.
/** Merge the common operations Bind, Filter and CollectionCast into an existing Comprehension.
* This method is used at different stages of the pipeline. */
def mergeCommon(rec: (Node, Boolean) => (Comprehension, Replacements), parent: (Node, Boolean) => (Comprehension, Replacements),
n: Node, buildBase: Boolean,
Expand All @@ -277,7 +284,7 @@ class MergeToComprehensions extends Phase {
logger.debug("Merging Filter into Comprehension:", Ellipsis(n, List(0)))
val p2 = applyReplacements(p1, replacements1, c1)
val c2 =
if(c1.groupBy.isEmpty) c1.copy(where = Some(c1.where.fold(p2)(and(p2, _)).infer())) :@ c1.nodeType
if(c1.groupBy.isEmpty) c1.copy(where = Some(c1.where.fold(p2)(and(_, p2)).infer())) :@ c1.nodeType
else c1.copy(having = Some(c1.having.fold(p2)(and(p2, _)).infer())) :@ c1.nodeType
logger.debug("Merged Filter into Comprehension:", c2)
(c2, replacements1)
Expand Down

0 comments on commit aaf2a22

Please sign in to comment.