Permalink
Browse files

Scalar optimization & better discriminator column picking:

- A new compiler phase `optimizeScalar` performs the required local
  optimizations to eliminate unnecessary null checks arising from
  outer joins after `expandSums`.

- `expandSums` now keeps track of paths referenced within OptionApply
  nodes in Join and Filter predicates. These are preferred over other
  fields when picking a discriminator column for an outer join. The
  goal is to avoid NVL2 checks of the discriminator column in Join and
  Filter conditions (where this could prevent the use of an index).
  In the long term this may not be enough. A more complex alternative
  would be to pick *all* possible discriminators in `expandSums` and
  narrow them down to the best one in a later phase on a case by case
  basis.
  • Loading branch information...
szeiger committed Aug 21, 2015
1 parent b01e539 commit 1c9469f9b7d699ab677ed0869a6bc1ab2f5fdd87
@@ -31,6 +31,7 @@
<logger name="slick.compiler.HoistClientOps" level="${log.qcomp.hoistClientOps:-inherited}" />
<logger name="slick.compiler.ReorderOperations" level="${log.qcomp.reorderOperations:-inherited}" />
<logger name="slick.compiler.MergeToComprehensions" level="${log.qcomp.mergeToComprehensions:-inherited}" />
<logger name="slick.compiler.OptimizeScalar" level="${log.qcomp.optimizeScalar:-inherited}" />
<logger name="slick.compiler.FixRowNumberOrdering" level="${log.qcomp.fixRowNumberOrdering:-inherited}" />
<logger name="slick.compiler.PruneProjections" level="${log.qcomp.pruneProjections:-inherited}" />
<logger name="slick.compiler.RewriteBooleans" level="${log.qcomp.rewriteBooleans:-inherited}" />
@@ -31,6 +31,7 @@
<logger name="slick.compiler.HoistClientOps" level="${log.qcomp.hoistClientOps:-inherited}" />
<logger name="slick.compiler.ReorderOperations" level="${log.qcomp.reorderOperations:-inherited}" />
<logger name="slick.compiler.MergeToComprehensions" level="${log.qcomp.mergeToComprehensions:-inherited}" />
<logger name="slick.compiler.OptimizeScalar" level="${log.qcomp.optimizeScalar:-inherited}" />
<logger name="slick.compiler.FixRowNumberOrdering" level="${log.qcomp.fixRowNumberOrdering:-inherited}" />
<logger name="slick.compiler.PruneProjections" level="${log.qcomp.pruneProjections:-inherited}" />
<logger name="slick.compiler.RewriteBooleans" level="${log.qcomp.rewriteBooleans:-inherited}" />
@@ -5,18 +5,26 @@ import slick.ast._
import Util._
import TypeUtil._
import scala.collection.mutable
/** Expand sum types and their catamorphisms to equivalent product type operations. */
class ExpandSums extends Phase {
val name = "expandSums"
def apply(state: CompilerState) = state.map(tr)
def apply(state: CompilerState) = state.map(n => tr(n, Set.empty))
val Disc1 = LiteralNode(ScalaBaseType.optionDiscType.optionType, Option(1))
val DiscNone = LiteralNode(ScalaBaseType.optionDiscType.optionType, None)
/** Perform the sum expansion on a Node */
def tr(tree: Node): Node = {
val tree2 = tree.mapChildren(tr, keepType = true)
def tr(tree: Node, oldDiscCandidates: Set[(TypeSymbol, List[TermSymbol])]): Node = {
val discCandidates = oldDiscCandidates ++ (tree match {
case Filter(_, _, p) => collectDiscriminatorCandidates(p)
case Bind(_, j: Join, _) => collectDiscriminatorCandidates(j.on)
case _ => Set.empty
})
val tree2 = tree.mapChildren(n => tr(n, discCandidates), keepType = true)
val tree3 = tree2 match {
// Expand multi-column null values in ELSE branches (used by Rep[Option].filter) with correct type
case IfThenElse(IndexedSeq(pred, then1 :@ tpe, LiteralNode(None) :@ OptionType(ScalaBaseType.nullType))) =>
@@ -73,7 +81,7 @@ class ExpandSums extends Phase {
// Option-extended left outer, right outer or full outer join
case bind @ Bind(bsym, Join(_, _, _, _, jt, _), _) if jt == JoinType.LeftOption || jt == JoinType.RightOption || jt == JoinType.OuterOption =>
translateJoin(bind)
translateJoin(bind, discCandidates)
case n => n
}
@@ -82,7 +90,7 @@ class ExpandSums extends Phase {
}
/** Translate an Option-extended left outer, right outer or full outer join */
def translateJoin(bind: Bind): Bind = {
def translateJoin(bind: Bind, discCandidates: Set[(TypeSymbol, List[TermSymbol])]): Bind = {
logger.debug("translateJoin", bind)
val Bind(bsym, (join @ Join(lsym, rsym, left :@ CollectionType(_, leftElemType), right :@ CollectionType(_, rightElemType), jt, on)) :@ CollectionType(cons, elemType), pure) = bind
val lComplex = leftElemType.structural.children.nonEmpty
@@ -91,17 +99,29 @@ class ExpandSums extends Phase {
// Find an existing column that can serve as a discriminator
def findDisc(t: Type): Option[List[TermSymbol]] = {
val global: Set[List[TermSymbol]] = t match {
case NominalType(ts, exp) =>
val c = discCandidates.filter { case (t, ss) => t == ts && ss.nonEmpty }.map(_._2)
logger.debug("Discriminator candidates from surrounding Filter and Join predicates: "+
c.map(Path.toString).mkString(", "))
c
case _ => Set.empty
}
def find(t: Type, path: List[TermSymbol]): Vector[List[TermSymbol]] = t.structural match {
case StructType(defs) => defs.flatMap { case (s, t) => find(t, s :: path) }(collection.breakOut)
case p: ProductType => p.numberedElements.flatMap { case (s, t) => find(t, s :: path) }.toVector
case _: AtomicType => Vector(path)
case _ => Vector.empty
}
find(t, Nil).sortBy(ss => ss.head match {
case f: FieldSymbol =>
if(f.options contains ColumnOption.PrimaryKey) -2 else -1
case _ => 0
}).headOption
val local = find(t, Nil).sortBy { ss =>
(if(global contains ss) 3 else 1) * (ss.head match {
case f: FieldSymbol =>
if(f.options contains ColumnOption.PrimaryKey) -2 else -1
case _ => 0
})
}
logger.debug("Local candidates: "+local.map(Path.toString).mkString(", "))
local.headOption
}
// Option-extend one side of the join with a discriminator column
@@ -207,4 +227,20 @@ class ExpandSums extends Phase {
ProductNode(Vector(disc, map)).infer()
case n => n
}
/** Collect discriminator candidate fields in a predicate. These are all paths below an
* OptionApply, which indicates their future use under a discriminator guard. */
def collectDiscriminatorCandidates(n: Node): Set[(TypeSymbol, List[TermSymbol])] = n.collectAll[(TypeSymbol, List[TermSymbol])] {
case OptionApply(ch) =>
ch.collect[(TypeSymbol, List[TermSymbol])] { case PathOnTypeSymbol(ts, ss) => (ts, ss) }
}.toSet
object PathOnTypeSymbol {
def unapply(n: Node): Option[(TypeSymbol, List[TermSymbol])] = n match {
case (n: PathElement) :@ NominalType(ts, _) => Some((ts, Nil))
case Select(in, s) => unapply(in).map { case (ts, l) => (ts, s :: l) }
case Library.SilentCast(ch) => unapply(ch)
case _ => None
}
}
}
@@ -0,0 +1,54 @@
package slick.compiler
import slick.ast.TypeUtil._
import slick.ast.Util._
import slick.ast._
/** Optimize scalar expressions */
class OptimizeScalar extends Phase {
val name = "optimizeScalar"
def apply(state: CompilerState) = state.map(_.tree.replace({
// (if(p) a else b) == v
case n @ Library.==(IfThenElse(Seq(p, Const(a), Const(b))), Const(v)) =>
val checkTrue = v == a
val checkFalse = v == b
val res =
if(checkTrue && checkFalse) LiteralNode(true)
else if(checkTrue && !checkFalse) p
else if(checkFalse) Library.Not.typed(p.nodeType, p)
else LiteralNode(false)
cast(n.nodeType, res).infer()
// if(v != null) v else null
case n @ IfThenElse(Seq(Library.Not(Library.==(v, LiteralNode(null))), v2, LiteralNode(z)))
if v == v2 && (z == null || z == None) =>
v
// Redundant cast to non-nullable within OptionApply
case o @ OptionApply(Library.SilentCast(n)) if o.nodeType == n.nodeType => n
// Rownum comparison with offset 1, arising from zipWithIndex
case n @ Library.<(Library.-(r: RowNumber, LiteralNode(1L)), v) =>
Library.<=.typed(n.nodeType, r, v).infer()
// Some(v).getOrElse(_)
case n @ Library.IfNull(OptionApply(ch), _) =>
cast(n.nodeType, ch)
}, keepType = true, bottomUp = true))
object Const {
def unapply(n: Node): Option[Node] = n match {
case _: LiteralNode => Some(n)
case Apply(Library.SilentCast, Seq(ch)) => unapply(ch)
case OptionApply(ch) => unapply(ch)
case _ => None
}
}
def cast(tpe: Type, n: Node): Node = {
val n2 = n.infer()
if(n2.nodeType == tpe) n2 else Library.SilentCast.typed(tpe, n2)
}
}
@@ -132,6 +132,7 @@ object QueryCompiler {
Phase.hoistClientOps,
Phase.reorderOperations,
Phase.mergeToComprehensions,
Phase.optimizeScalar,
Phase.fixRowNumberOrdering,
Phase.removeFieldNames
// optional rewriteBooleans goes here
@@ -186,6 +187,7 @@ object Phase {
val reorderOperations = new ReorderOperations
val relabelUnions = new RelabelUnions
val mergeToComprehensions = new MergeToComprehensions
val optimizeScalar = new OptimizeScalar
val fixRowNumberOrdering = new FixRowNumberOrdering
val pruneProjections = new PruneProjections
val removeFieldNames = new RemoveFieldNames

0 comments on commit 1c9469f

Please sign in to comment.