
Eliminate explicit outer join discriminator columns where possible

- In `expandSums` we now check for a suitable existing column in the
  type of the join side to expand, and use it as the discriminator
  instead of generating an additional column that always leads to a
  subquery in SQL ("select 1, ..."). Suitable columns must be of a
  primitive non-Option type; primary keys are preferred over other
  fields, and other fields over computed columns. A conditional
  expression in the wrapping Bind converts the column to the proper
  discriminator type `Option[Int]`. (See the query sketch after this
  message.)

- These conditional expressions are eliminated in `hoistClientOps` when
  they occur at the top level, so the substitute discriminator is passed
  directly to the client side.

- A more flexible `IsDefinedResultConverter` is used for reading
  discriminator columns of any Option type (only NULL vs non-NULL
  matters). JdbcProfile uses a specialized, fused version.

- When computing the new StructNodes in `flattenProjections`, duplicate
  definitions are now removed, except at the top level (where the
  projection has to match the translated top-level type) and directly
  below a Union (where both sides have to match up).

- `hoistClientOps` can now pull operations out of the non-Option sides
  of joins, which is required to avoid subqueries in the case of nested
  outer joins.

Tests in NewQuerySemanticsTest.testNewFusion and UnionTest.testBasicUnions.
Fixes #1241.
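
Illustrative sketch (not part of this commit): a query of the affected
shape, modeled on the new q18 test case in the diff below. The joined
table is assumed to have an O.PrimaryKey column "id", so the left outer
join can reuse it as the discriminator instead of wrapping the join side
in a "select 1, ..." subquery.

    // Table with a primary key, as in the A_NEWFUSION test table below
    class A(tag: Tag) extends Table[(Int, String, String)](tag, "A_NEWFUSION") {
      def id = column[Int]("id", O.PrimaryKey)
      def a = column[String]("a")
      def b = column[String]("b")
      def * = (id, a, b)
    }
    val as = TableQuery[A]

    // Left outer join: the Option side previously needed an extra
    // generated discriminator column; now the primary key "id" is used.
    val q = as.joinLeft(as).on(_.id === _.id).map { case (a1, a2) => a2 }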
szeiger committed Aug 18, 2015
1 parent 6b92d09 commit b87c994a4ed6cd253a4e8672e46c56aface0958b
@@ -469,7 +469,7 @@ class NewQuerySemanticsTest extends AsyncTest[RelationalTestDB] {
def testNewFusion = {
class A(tag: Tag) extends Table[(Int, String, String)](tag, "A_NEWFUSION") {
def id = column[Int]("id")
def id = column[Int]("id", O.PrimaryKey)
def a = column[String]("a")
def b = column[String]("b")
def * = (id, a, b)
@@ -503,6 +503,8 @@ class NewQuerySemanticsTest extends AsyncTest[RelationalTestDB] {
val q15 = (as.map(a => a.id.?).filter(_ < 2) unionAll as.map(a => a.id.?).filter(_ > 2)).map(_.get).to[Set]
val q16 = (as.map(a => a.id.?).filter(_ < 2) unionAll as.map(a => a.id.?).filter(_ > 2)).map(_.getOrElse(-1)).to[Set].filter(_ =!= 42)
val q17 = as.sortBy(_.id).zipWithIndex.filter(_._2 < 2L).map { case (a, i) => (a.id, i) }
val q18 = as.joinLeft(as).on { case (a1, a2) => a1.id === a2.id }.filter { case (a1, a2) => a1.id === 3 }.map { case (a1, a2) => a2 }
val q19 = as.joinLeft(as).on { case (a1, a2) => a1.id === a2.id }.joinLeft(as).on { case ((_, a2), a3) => a2.map(_.b) === a3.b }.map(_._2)
if(tdb.driver == H2Driver) {
assertNesting(q1, 1)
@@ -530,6 +532,8 @@ class NewQuerySemanticsTest extends AsyncTest[RelationalTestDB] {
assertNesting(q15, 2)
assertNesting(q16, 2)
assertNesting(q17, 2)
assertNesting(q18, 1)
assertNesting(q19, 1)
}
for {
@@ -561,6 +565,8 @@ class NewQuerySemanticsTest extends AsyncTest[RelationalTestDB] {
_ <- mark("q15", q15.result).map(_ shouldBe Set(1, 3))
_ <- mark("q16", q16.result).map(_ shouldBe Set(1, 3))
_ <- ifCap(rcap.zip)(mark("q17", q17.result).map(_ shouldBe Seq((1,0), (2,1))))
_ <- mark("q18", q18.result).map(_ shouldBe Seq(Some((3, "c", "b"))))
_ <- mark("q19", q19.result).map(_.toSet shouldBe Set(Some((1,"a","a")), Some((2,"a","b")), Some((3,"c","b"))))
} yield ()
}
@@ -24,13 +24,14 @@ class UnionTest extends AsyncTest[RelationalTestDB] {
}
lazy val employees = TableQuery[Employees]
def testBasic = {
def testBasicUnions = {
val q1 = for(m <- managers filter { _.department === "IT" }) yield (m.id, m.name)
val q2 = for(e <- employees filter { _.departmentIs("IT") }) yield (e.id, e.name)
val q3 = (q1 union q2).sortBy(_._2.asc)
val q4 = managers.map(_.id)
val q4b = q4 union q4
val q4c = q4 union q4 union q4
val q5 = managers.map(m => (m.id, 0)) union employees.map(e => (e.id, e.id))
(for {
_ <- (managers.schema ++ employees.schema).create
@@ -51,6 +52,7 @@ class UnionTest extends AsyncTest[RelationalTestDB] {
_ <- mark("q3", q3.result).map(_ shouldBe List((2,"Amy"), (7,"Ben"), (8,"Greg"), (6,"Leonard"), (3,"Steve")))
_ <- mark("q4b", q4b.result).map(r => r.toSet shouldBe Set(1, 2, 3))
_ <- mark("q4c", q4c.result).map(r => r.toSet shouldBe Set(1, 2, 3))
_ <- mark("q5", q5.result).map(r => r.toSet shouldBe Set((7,7), (6,6), (2,0), (4,4), (3,0), (8,8), (5,5), (1,0)))
} yield ()) andFinally (managers.schema ++ employees.schema).drop
}
@@ -48,7 +48,8 @@ class CreateResultSetMapping extends Phase {
case t: MappedScalaType =>
TypeMapping(f(t.baseType), t.mapper, t.classTag)
case o @ OptionType(Type.Structural(el)) if el.children.nonEmpty =>
val discriminator = f(ScalaBaseType.intType.optionType)
val discriminator = Select(ref, syms(curIdx)).infer()
curIdx += 1
val data = f(o.elementType)
RebuildOption(discriminator, data)
case t =>
@@ -89,35 +89,69 @@ class ExpandSums extends Phase {
val rComplex = rightElemType.structural.children.nonEmpty
logger.debug(s"Translating join ($jt, complex: $lComplex, $rComplex):", bind)
// Find an existing column that can serve as a discriminator
def findDisc(t: Type): Option[List[TermSymbol]] = {
def find(t: Type, path: List[TermSymbol]): Vector[List[TermSymbol]] = t.structural match {
case StructType(defs) => defs.flatMap { case (s, t) => find(t, s :: path) }(collection.breakOut)
case p: ProductType => p.numberedElements.flatMap { case (s, t) => find(t, s :: path) }.toVector
case _: AtomicType => Vector(path)
case _ => Vector.empty
}
find(t, Nil).sortBy(ss => ss.head match {
case f: FieldSymbol =>
if(f.options contains ColumnOption.PrimaryKey) -2 else -1
case _ => 0
}).headOption
}
// Option-extend one side of the join with a discriminator column
def extend(side: Node, sym: TermSymbol, on: Node): (Node, Node) = {
def extend(side: Node, sym: TermSymbol, on: Node): (Node, Node, Boolean) = {
val extendGen = new AnonSymbol
val extend :@ CollectionType(_, extendedElementType) = Bind(extendGen, side, Pure(ProductNode(Vector(Disc1, Ref(extendGen))))).infer()
val elemType = side.nodeType.asCollectionType.elementType
val (disc, createDisc) = findDisc(elemType) match {
case Some(path) =>
logger.debug("Using existing column "+Path(path)+" as discriminator in "+elemType)
(FwdPath(extendGen :: path.reverse), true)
case None =>
logger.debug("No suitable discriminator column found in "+elemType)
(Disc1, false)
}
val extend :@ CollectionType(_, extendedElementType) = Bind(extendGen, side, Pure(ProductNode(Vector(disc, Ref(extendGen))))).infer()
val sideInCondition = Select(Ref(sym) :@ extendedElementType, ElementSymbol(2)).infer()
val on2 = on.replace({
case Ref(s) if s == sym => sideInCondition
case n @ Select(in, _) => n.infer()
}, bottomUp = true).infer()
(extend, on2)
(extend, on2, createDisc)
}
// Translate the join depending on JoinType and Option type
val (left2, right2, on2, jt2) = jt match {
val (left2, right2, on2, jt2, ldisc, rdisc) = jt match {
case JoinType.LeftOption =>
val (right2, on2) = if(rComplex) extend(right, rsym, on) else (right, on)
(left, right2, on2, JoinType.Left)
val (right2, on2, rdisc) = if(rComplex) extend(right, rsym, on) else (right, on, false)
(left, right2, on2, JoinType.Left, false, rdisc)
case JoinType.RightOption =>
val (left2, on2) = if(lComplex) extend(left, lsym, on) else (left, on)
(left2, right, on2, JoinType.Right)
val (left2, on2, ldisc) = if(lComplex) extend(left, lsym, on) else (left, on, false)
(left2, right, on2, JoinType.Right, ldisc, false)
case JoinType.OuterOption =>
val (left2, on2) = if(lComplex) extend(left, lsym, on) else (left, on)
val (right2, on3) = if(rComplex) extend(right, rsym, on2) else (right, on2)
(left2, right2, on3, JoinType.Outer)
val (left2, on2, ldisc) = if(lComplex) extend(left, lsym, on) else (left, on, false)
val (right2, on3, rdisc) = if(rComplex) extend(right, rsym, on2) else (right, on2, false)
(left2, right2, on3, JoinType.Outer, ldisc, rdisc)
}
// Cast to translated Option type in outer bind
val join2 :@ CollectionType(_, elemType2) = Join(lsym, rsym, left2, right2, jt2, on2).infer()
val ref = silentCast(trType(elemType), Ref(bsym) :@ elemType2)
def optionCast(idx: Int, createDisc: Boolean): Node = {
val ref = Select(Ref(bsym) :@ elemType2, ElementSymbol(idx+1))
val v = if(createDisc) {
val protoDisc = Select(ref, ElementSymbol(1)).infer()
val rest = Select(ref, ElementSymbol(2))
val disc = IfThenElse(Vector(Library.==.typed[Boolean](silentCast(OptionType(protoDisc.nodeType), protoDisc), LiteralNode(null)), DiscNone, Disc1))
ProductNode(Vector(disc, rest))
} else ref
silentCast(trType(elemType.asInstanceOf[ProductType].children(idx)), v)
}
val ref = ProductNode(Vector(optionCast(0, ldisc), optionCast(1, rdisc))).infer()
val pure2 = pure.replace({
case Ref(s) if s == bsym => ref
@@ -14,10 +14,10 @@ class FlattenProjections extends Phase {
def apply(state: CompilerState) = state.map { tree =>
val translations = new HashMap[TypeSymbol, (Map[List[TermSymbol], TermSymbol], StructType)]
def tr(n: Node): Node = n match {
def tr(n: Node, topLevel: Boolean): Node = n match {
case Pure(v, ts) =>
logger.debug(s"Flattening projection $ts")
val (newV, newTranslations) = flattenProjection(tr(v))
val (newV, newTranslations) = flattenProjection(tr(v, false), !topLevel)
translations += ts -> (newTranslations, newV.nodeType.asInstanceOf[StructType])
logger.debug(s"Adding translation for $ts: ($newTranslations, ${newV.nodeType})")
val res = Pure(newV, ts).infer()
@@ -42,9 +42,13 @@ class FlattenProjections extends Phase {
}
logger.debug("Translated "+Path.toString(path)+" to:", p2)
p2
case n => n.mapChildren(tr)
case n: Bind =>
n.mapScopedChildren { case (o, ch) => tr(ch, topLevel && o.isEmpty) }
case u: Union =>
n.mapChildren { ch => tr(ch, true) }
case n => n.mapChildren(tr(_, false))
}
tr(tree).infer()
tr(tree, true).infer()
}
/** Split a path into the shortest part with a NominalType and the rest on
@@ -69,9 +73,15 @@ class FlattenProjections extends Phase {
}
}
/** Flatten a projection into a StructNode. */
def flattenProjection(n: Node): (StructNode, Map[List[TermSymbol], TermSymbol]) = {
/** Flatten a projection into a StructNode.
* @param collapse If set to true, duplicate definitions are combined into a single one. This
* must not be used in the top-level Bind because the definitions have to match the top-level
* type (which is used later in `createResultSetMapping`). Any duplicates there will be
* eliminated in `hoistClientOps`. It is also disabled directly under a Union because the
* columns on both sides have to match up. */
def flattenProjection(n: Node, collapse: Boolean): (StructNode, Map[List[TermSymbol], TermSymbol]) = {
val defs = new ArrayBuffer[(TermSymbol, Node)]
val defsM = new HashMap[Node, TermSymbol]
val paths = new HashMap[List[TermSymbol], TermSymbol]
def flatten(n: Node, path: List[TermSymbol]) {
logger.debug("Flattening node at "+Path.toString(path), n)
@@ -80,10 +90,17 @@ class FlattenProjections extends Phase {
case p: ProductNode =>
p.children.iterator.zipWithIndex.foreach { case (n, i) => flatten(n, new ElementSymbol(i+1) :: path) }
case n =>
val sym = new AnonSymbol
logger.debug(s"Adding definition: $sym -> $n")
defs += sym -> n
paths += path -> sym
defsM.get(n) match {
case Some(sym) if collapse =>
logger.debug(s"Reusing definition: $sym -> $n")
paths += path -> sym
case _ =>
val sym = new AnonSymbol
logger.debug(s"Adding definition: $sym -> $n")
defs += sym -> n
defsM += n -> sym
paths += path -> sym
}
}
}
flatten(n, Nil)
@@ -17,9 +17,10 @@ class HoistClientOps extends Phase {
from1 match {
case Bind(s2, from2, Pure(StructNode(defs2), ts2)) =>
// Extract client-side operations into ResultSetMapping
val hoisted = defs2.map { case (ts, n) => (ts, n, unwrap(n)) }
val hoisted = defs2.map { case (ts, n) => (ts, n, unwrap(n, true)) }
logger.debug("Hoisting operations from defs: " + hoisted.filter(t => t._2 ne t._3._1).map(_._1).mkString(", "))
val newDefsM = hoisted.map { case (ts, n, (n2, wrap)) => (n2, new AnonSymbol) }.toMap
logger.debug("New defs: "+newDefsM)
val oldDefsM = hoisted.map { case (ts, n, (n2, wrap)) => (ts, wrap(Select(Ref(rsm.generator), newDefsM(n2)))) }.toMap
val bind2 = rewriteDBSide(Bind(s2, from2, Pure(StructNode(newDefsM.map(_.swap).toVector), new AnonTypeSymbol)).infer())
val rsm2 = rsm.copy(from = bind2, map = rsm.map.replace {
@@ -37,12 +38,63 @@ class HoistClientOps extends Phase {
def shuffle(n: Node): Node = n match {
case n @ Bind(s1, from1, sel1) =>
shuffle(from1) match {
// Merge nested Binds
case bind2 @ Bind(s2, from2, sel2 @ Pure(StructNode(elems2), ts2)) if !from2.isInstanceOf[GroupBy] =>
logger.debug("Merging top-level Binds", Ellipsis(n.copy(from = bind2), List(0,0)))
val defs = elems2.toMap
bind2.copy(select = sel1.replace {
case Select(Ref(s), f) if s == s1 => defs(f)
}).infer()
// Hoist operations out of the non-Option sides of inner and left and right outer joins
case from2 @ Join(sl1, sr1, bl @ Bind(bsl, lfrom, Pure(StructNode(ldefs), tsl)),
br @ Bind(bsr, rfrom, Pure(StructNode(rdefs), tsr)),
jt, on1) if jt != JoinType.Outer =>
logger.debug("Hoisting operations from Join:", Ellipsis(from2, List(0, 0), List(1, 0)))
val (bl2: Bind, lrepl: Map[TermSymbol, (Node => Node, AnonSymbol)]) = if(jt != JoinType.Right) {
val hoisted = ldefs.map { case (ts, n) => (ts, n, unwrap(n, false)) }
logger.debug("Hoisting operations from defs in left side of Join: " + hoisted.filter(t => t._2 ne t._3._1).map(_._1).mkString(", "))
val newDefsM = hoisted.map { case (ts, n, (n2, wrap)) => (n2, new AnonSymbol) }.toMap
logger.debug("New defs: "+newDefsM)
val bl2 = bl.copy(select = Pure(StructNode(newDefsM.map(_.swap).toVector))).infer()
logger.debug("Translated left join side:", Ellipsis(bl2, List(0)))
val repl = hoisted.map { case (s, _, (n2, wrap)) => (s, (wrap, newDefsM(n2))) }.toMap
(bl2, repl)
} else (bl, Map.empty)
val (br2: Bind, rrepl: Map[TermSymbol, (Node => Node, AnonSymbol)]) = if(jt != JoinType.Left) {
val hoisted = rdefs.map { case (ts, n) => (ts, n, unwrap(n, false)) }
logger.debug("Hoisting operations from defs in right side of Join: " + hoisted.filter(t => t._2 ne t._3._1).map(_._1).mkString(", "))
val newDefsM = hoisted.map { case (ts, n, (n2, wrap)) => (n2, new AnonSymbol) }.toMap
logger.debug("New defs: "+newDefsM)
val br2 = br.copy(select = Pure(StructNode(newDefsM.map(_.swap).toVector))).infer()
logger.debug("Translated right join side:", Ellipsis(br2, List(0)))
val repl = hoisted.map { case (s, _, (n2, wrap)) => (s, (wrap, newDefsM(n2))) }.toMap
(br2, repl)
} else (br, Map.empty)
if((bl2 ne bl) || (br2 ne br)) {
val from3 = from2.copy(left = bl2, right = br2, on = on1.replace {
case Select(Ref(s), f) if s == sl1 && (bl2 ne bl) =>
val (wrap, f2) = lrepl(f)
wrap(Select(Ref(s), f2))
case Select(Ref(s), f) if s == sr1 && (br2 ne br) =>
val (wrap, f2) = rrepl(f)
wrap(Select(Ref(s), f2))
case Ref(s) if (s == sl1 && (bl2 ne bl)) || (s == sr1 && (br2 ne br)) =>
Ref(s)
})
val sel2 = sel1.replace {
case Select(Select(Ref(s), ElementSymbol(1)), f) if s == s1 && (bl2 ne bl) =>
val (wrap, f2) = lrepl(f)
wrap(Select(Select(Ref(s), ElementSymbol(1)), f2))
case Select(Select(Ref(s), ElementSymbol(2)), f) if s == s1 && (br2 ne br) =>
val (wrap, f2) = rrepl(f)
wrap(Select(Select(Ref(s), ElementSymbol(2)), f2))
case Ref(s) if s == s1 => Ref(s)
}
logger.debug("from3", from3)
logger.debug("sel2", sel2)
n.copy(from = from3, select = sel2).infer()
} else if(from2 eq from1) n
else n.copy(from = from2) :@ n.nodeType
case from2 =>
if(from2 eq from1) n else n.copy(from = from2) :@ n.nodeType
}
@@ -76,15 +128,23 @@ class HoistClientOps extends Phase {
case n => n
}
/** Remove a hoistable operation from a top-level column and create a function to
* reapply it at the client side. */
def unwrap(n: Node): (Node, (Node => Node)) = n match {
/** Remove a hoistable operation from a top-level column or join column and create a
* function to reapply it at an outer layer. */
def unwrap(n: Node, topLevel: Boolean): (Node, (Node => Node)) = n match {
case GetOrElse(ch, default) =>
val (recCh, recTr) = unwrap(ch)
val (recCh, recTr) = unwrap(ch, topLevel)
(recCh, { sym => GetOrElse(recTr(sym), default) })
case OptionApply(ch) =>
val (recCh, recTr) = unwrap(ch)
val (recCh, recTr) = unwrap(ch, topLevel)
(recCh, { sym => OptionApply(recTr(sym)) })
case IfThenElse(Seq(Library.==(ch, LiteralNode(null)), r1 @ LiteralNode(None), r2 @ LiteralNode(Some(1)))) :@ OptionType(t)
if t == ScalaBaseType.optionDiscType =>
val (recCh, recTr) = unwrap(ch, topLevel)
if(topLevel) (recCh, recTr)
else (recCh, { n => IfThenElse(Vector(Library.==.typed[Boolean](recTr(n), LiteralNode(null)), r1, r2)) })
case Library.SilentCast(ch) :@ tpe if !topLevel =>
val (recCh, recTr) = unwrap(ch, topLevel)
(recCh, { n => Library.SilentCast.typed(tpe, recTr(n)) })
case n => (n, identity)
}
@@ -164,9 +164,11 @@ class MergeToComprehensions extends Phase {
logger.debug("Mappings are: "+mappings)
Some((p, mappings))
case j @ Join(ls, rs, l1, r1, jt, on1) =>
logger.debug("Creating source from Join:", j)
logger.debug(s"Creating source from Join $ls/$rs:", j)
val (l2 @ (_ :@ CollectionType(_, ltpe)), lmap) = dealias(l1)(createSourceOrTopLevel)
val (r2 @ (_ :@ CollectionType(_, rtpe)), rmap) = dealias(r1)(createSourceOrTopLevel)
logger.debug(s"Converted left side of Join $ls/$rs:", l2)
logger.debug(s"Converted right side of Join $ls/$rs:", r2)
// Detect and remove empty join sides
val noCondition = on1 == LiteralNode(true).infer()
val noLeft = l2 match {
@@ -186,7 +188,7 @@ class MergeToComprehensions extends Phase {
lmap.map { case (key, ss) => (key, ElementSymbol(1) :: ss )} ++
rmap.map { case (key, ss) => (key, ElementSymbol(2) :: ss )}
val mappingsM = mappings.toMap
logger.debug("Mappings for `on` clause: "+mappingsM)
logger.debug(s"Mappings for `on` clause in Join $ls/$rs: "+mappingsM)
val on2 = on1.replace({
case p @ FwdPathOnTypeSymbol(ts, _ :: s :: Nil) =>
//logger.debug(s"Finding ($ts, $s)")
@@ -198,10 +200,10 @@ class MergeToComprehensions extends Phase {
}
}, bottomUp = true).infer(
scope = Type.Scope(j.leftGen -> l2.nodeType.asCollectionType.elementType) +
(j.rightGen -> r2.nodeType.asCollectionType.elementType))
logger.debug("Transformed `on` clause:", on2)
(j.rightGen -> r2.nodeType.asCollectionType.elementType))
logger.debug(s"Transformed `on` clause in Join $ls/$rs:", on2)
val j2 = j.copy(left = l2, right = r2, on = on2).infer()
logger.debug("Created source from Join:", j2)
logger.debug(s"Created source from Join $ls/$rs:", j2)
Some((j2, mappings))
}
case n => None
@@ -38,6 +38,11 @@ trait JdbcMappingCompilerComponent { driver: JdbcDriver =>
case _ => super.createGetOrElseResultConverter[T](rc, default)
}
override def createIsDefinedResultConverter[T](rc: ResultConverter[JdbcResultConverterDomain, Option[T]]) = rc match {
case rc: OptionResultConverter[_] => rc.isDefined
case _ => super.createIsDefinedResultConverter(rc)
}
override def createTypeMappingResultConverter(rc: ResultConverter[JdbcResultConverterDomain, Any], mapper: MappedScalaType.Mapper) = {
val tm = new TypeMappingResultConverter(rc, mapper.toBase, mapper.toMapped)
mapper.fastPath match {