Skip to content

Commit

Permalink
Scalar optimization & better discriminator column picking:
Browse files Browse the repository at this point in the history
- A new compiler phase `optimizeScalar` performs the required local
  optimizations to eliminate unnecessary null checks arising from
  outer joins after `expandSums`.

- `expandSums` now keeps track of paths referenced within OptionApply
  nodes in Join and Filter predicates. These are preferred over other
  fields when picking a discriminator column for an outer join. The
  goal is to avoid NVL2 checks of the discriminator column in Join and
  Filter conditions (where this could prevent the use of an index).
  In the long term this may not be enough. A more complex alternative
  would be to pick *all* possible discriminators in `expandSums` and
  narrow them down to the best one in a later phase on a case by case
  basis.
  • Loading branch information
szeiger committed Aug 21, 2015
1 parent b01e539 commit 1c9469f
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 10 deletions.
1 change: 1 addition & 0 deletions common-test-resources/logback.xml
Expand Up @@ -31,6 +31,7 @@
<logger name="slick.compiler.HoistClientOps" level="${log.qcomp.hoistClientOps:-inherited}" />
<logger name="slick.compiler.ReorderOperations" level="${log.qcomp.reorderOperations:-inherited}" />
<logger name="slick.compiler.MergeToComprehensions" level="${log.qcomp.mergeToComprehensions:-inherited}" />
<logger name="slick.compiler.OptimizeScalar" level="${log.qcomp.optimizeScalar:-inherited}" />
<logger name="slick.compiler.FixRowNumberOrdering" level="${log.qcomp.fixRowNumberOrdering:-inherited}" />
<logger name="slick.compiler.PruneProjections" level="${log.qcomp.pruneProjections:-inherited}" />
<logger name="slick.compiler.RewriteBooleans" level="${log.qcomp.rewriteBooleans:-inherited}" />
Expand Down
1 change: 1 addition & 0 deletions slick-testkit/src/doctest/resources/logback.xml
Expand Up @@ -31,6 +31,7 @@
<logger name="slick.compiler.HoistClientOps" level="${log.qcomp.hoistClientOps:-inherited}" />
<logger name="slick.compiler.ReorderOperations" level="${log.qcomp.reorderOperations:-inherited}" />
<logger name="slick.compiler.MergeToComprehensions" level="${log.qcomp.mergeToComprehensions:-inherited}" />
<logger name="slick.compiler.OptimizeScalar" level="${log.qcomp.optimizeScalar:-inherited}" />
<logger name="slick.compiler.FixRowNumberOrdering" level="${log.qcomp.fixRowNumberOrdering:-inherited}" />
<logger name="slick.compiler.PruneProjections" level="${log.qcomp.pruneProjections:-inherited}" />
<logger name="slick.compiler.RewriteBooleans" level="${log.qcomp.rewriteBooleans:-inherited}" />
Expand Down
56 changes: 46 additions & 10 deletions slick/src/main/scala/slick/compiler/ExpandSums.scala
Expand Up @@ -5,18 +5,26 @@ import slick.ast._
import Util._
import TypeUtil._

import scala.collection.mutable

/** Expand sum types and their catamorphisms to equivalent product type operations. */
class ExpandSums extends Phase {
val name = "expandSums"

def apply(state: CompilerState) = state.map(tr)
def apply(state: CompilerState) = state.map(n => tr(n, Set.empty))

val Disc1 = LiteralNode(ScalaBaseType.optionDiscType.optionType, Option(1))
val DiscNone = LiteralNode(ScalaBaseType.optionDiscType.optionType, None)

/** Perform the sum expansion on a Node */
def tr(tree: Node): Node = {
val tree2 = tree.mapChildren(tr, keepType = true)
def tr(tree: Node, oldDiscCandidates: Set[(TypeSymbol, List[TermSymbol])]): Node = {
val discCandidates = oldDiscCandidates ++ (tree match {
case Filter(_, _, p) => collectDiscriminatorCandidates(p)
case Bind(_, j: Join, _) => collectDiscriminatorCandidates(j.on)
case _ => Set.empty
})

val tree2 = tree.mapChildren(n => tr(n, discCandidates), keepType = true)
val tree3 = tree2 match {
// Expand multi-column null values in ELSE branches (used by Rep[Option].filter) with correct type
case IfThenElse(IndexedSeq(pred, then1 :@ tpe, LiteralNode(None) :@ OptionType(ScalaBaseType.nullType))) =>
Expand Down Expand Up @@ -73,7 +81,7 @@ class ExpandSums extends Phase {

// Option-extended left outer, right outer or full outer join
case bind @ Bind(bsym, Join(_, _, _, _, jt, _), _) if jt == JoinType.LeftOption || jt == JoinType.RightOption || jt == JoinType.OuterOption =>
translateJoin(bind)
translateJoin(bind, discCandidates)

case n => n
}
Expand All @@ -82,7 +90,7 @@ class ExpandSums extends Phase {
}

/** Translate an Option-extended left outer, right outer or full outer join */
def translateJoin(bind: Bind): Bind = {
def translateJoin(bind: Bind, discCandidates: Set[(TypeSymbol, List[TermSymbol])]): Bind = {
logger.debug("translateJoin", bind)
val Bind(bsym, (join @ Join(lsym, rsym, left :@ CollectionType(_, leftElemType), right :@ CollectionType(_, rightElemType), jt, on)) :@ CollectionType(cons, elemType), pure) = bind
val lComplex = leftElemType.structural.children.nonEmpty
Expand All @@ -91,17 +99,29 @@ class ExpandSums extends Phase {

// Find an existing column that can serve as a discriminator
def findDisc(t: Type): Option[List[TermSymbol]] = {
val global: Set[List[TermSymbol]] = t match {
case NominalType(ts, exp) =>
val c = discCandidates.filter { case (t, ss) => t == ts && ss.nonEmpty }.map(_._2)
logger.debug("Discriminator candidates from surrounding Filter and Join predicates: "+
c.map(Path.toString).mkString(", "))
c
case _ => Set.empty
}
def find(t: Type, path: List[TermSymbol]): Vector[List[TermSymbol]] = t.structural match {
case StructType(defs) => defs.flatMap { case (s, t) => find(t, s :: path) }(collection.breakOut)
case p: ProductType => p.numberedElements.flatMap { case (s, t) => find(t, s :: path) }.toVector
case _: AtomicType => Vector(path)
case _ => Vector.empty
}
find(t, Nil).sortBy(ss => ss.head match {
case f: FieldSymbol =>
if(f.options contains ColumnOption.PrimaryKey) -2 else -1
case _ => 0
}).headOption
val local = find(t, Nil).sortBy { ss =>
(if(global contains ss) 3 else 1) * (ss.head match {
case f: FieldSymbol =>
if(f.options contains ColumnOption.PrimaryKey) -2 else -1
case _ => 0
})
}
logger.debug("Local candidates: "+local.map(Path.toString).mkString(", "))
local.headOption
}

// Option-extend one side of the join with a discriminator column
Expand Down Expand Up @@ -207,4 +227,20 @@ class ExpandSums extends Phase {
ProductNode(Vector(disc, map)).infer()
case n => n
}

/** Collect discriminator candidate fields in a predicate. These are all paths below an
* OptionApply, which indicates their future use under a discriminator guard. */
def collectDiscriminatorCandidates(n: Node): Set[(TypeSymbol, List[TermSymbol])] = n.collectAll[(TypeSymbol, List[TermSymbol])] {
case OptionApply(ch) =>
ch.collect[(TypeSymbol, List[TermSymbol])] { case PathOnTypeSymbol(ts, ss) => (ts, ss) }
}.toSet

object PathOnTypeSymbol {
def unapply(n: Node): Option[(TypeSymbol, List[TermSymbol])] = n match {
case (n: PathElement) :@ NominalType(ts, _) => Some((ts, Nil))
case Select(in, s) => unapply(in).map { case (ts, l) => (ts, s :: l) }
case Library.SilentCast(ch) => unapply(ch)
case _ => None
}
}
}
54 changes: 54 additions & 0 deletions slick/src/main/scala/slick/compiler/OptimizeScalar.scala
@@ -0,0 +1,54 @@
package slick.compiler

import slick.ast.TypeUtil._
import slick.ast.Util._
import slick.ast._

/** Optimize scalar expressions */
class OptimizeScalar extends Phase {
val name = "optimizeScalar"

def apply(state: CompilerState) = state.map(_.tree.replace({
// (if(p) a else b) == v
case n @ Library.==(IfThenElse(Seq(p, Const(a), Const(b))), Const(v)) =>
val checkTrue = v == a
val checkFalse = v == b
val res =
if(checkTrue && checkFalse) LiteralNode(true)
else if(checkTrue && !checkFalse) p
else if(checkFalse) Library.Not.typed(p.nodeType, p)
else LiteralNode(false)
cast(n.nodeType, res).infer()

// if(v != null) v else null
case n @ IfThenElse(Seq(Library.Not(Library.==(v, LiteralNode(null))), v2, LiteralNode(z)))
if v == v2 && (z == null || z == None) =>
v

// Redundant cast to non-nullable within OptionApply
case o @ OptionApply(Library.SilentCast(n)) if o.nodeType == n.nodeType => n

// Rownum comparison with offset 1, arising from zipWithIndex
case n @ Library.<(Library.-(r: RowNumber, LiteralNode(1L)), v) =>
Library.<=.typed(n.nodeType, r, v).infer()

// Some(v).getOrElse(_)
case n @ Library.IfNull(OptionApply(ch), _) =>
cast(n.nodeType, ch)

}, keepType = true, bottomUp = true))

object Const {
def unapply(n: Node): Option[Node] = n match {
case _: LiteralNode => Some(n)
case Apply(Library.SilentCast, Seq(ch)) => unapply(ch)
case OptionApply(ch) => unapply(ch)
case _ => None
}
}

def cast(tpe: Type, n: Node): Node = {
val n2 = n.infer()
if(n2.nodeType == tpe) n2 else Library.SilentCast.typed(tpe, n2)
}
}
2 changes: 2 additions & 0 deletions slick/src/main/scala/slick/compiler/QueryCompiler.scala
Expand Up @@ -132,6 +132,7 @@ object QueryCompiler {
Phase.hoistClientOps,
Phase.reorderOperations,
Phase.mergeToComprehensions,
Phase.optimizeScalar,
Phase.fixRowNumberOrdering,
Phase.removeFieldNames
// optional rewriteBooleans goes here
Expand Down Expand Up @@ -186,6 +187,7 @@ object Phase {
val reorderOperations = new ReorderOperations
val relabelUnions = new RelabelUnions
val mergeToComprehensions = new MergeToComprehensions
val optimizeScalar = new OptimizeScalar
val fixRowNumberOrdering = new FixRowNumberOrdering
val pruneProjections = new PruneProjections
val removeFieldNames = new RemoveFieldNames
Expand Down

0 comments on commit 1c9469f

Please sign in to comment.