Permalink
Browse files

Selectively skip compiler phases:

- In `assignUniqueSymbols` where a full traversal of the AST is done
  anyway, we detect the presence of certain node types to set flags for
  Distinct, TypeMapping / MappedScalaType, AggregateFunctionSymbol and
  Option operations.

- Depending on these flags, the phases `createAggregates`, `expandSums`,
  `removeMappedTypes` and `rewriteDistinct` can be skipped entirely.

- Another flag is set in `expandSums` indicating whether there is
  anything left to be cleaned up by `expandConditionals`, otherwise this
  phase is also skipped.

This reduces the runtime for CompilerBenchmark by about 15%.
  • Loading branch information...
szeiger committed Sep 7, 2015
1 parent 79dc6d5 commit d286c173dae82a7908c556d19a905babe5f7fd8b
@@ -1,44 +1,70 @@
package slick.compiler
import slick.ast.Library.AggregateFunctionSymbol
import scala.collection.mutable.{HashSet, HashMap}
import slick.SlickException
import slick.ast._
import TypeUtil._
/** Ensure that all symbol definitions in a tree are unique. The same symbol
* can initially occur in multiple sub-trees when some part of a query is
* reused multiple times. This phase assigns new, uniqe symbols, so that
* later phases do not have to take scopes into account for identifying
* the source of a symbol. The rewriting is performed for both, term symbols
* and type symbols. */
/** Ensure that all symbol definitions in a tree are unique. The same symbol can initially occur in
* multiple sub-trees when some part of a query is reused multiple times. This phase assigns new,
* uniqe symbols, so that later phases do not have to take scopes into account for identifying the
* source of a symbol. The rewriting is performed for both, term symbols and type symbols.
*
* The phase state is a collection of flags depending on the presence or absence of certain node
* types in the AST. This information can be used to selectively skip later compiler phases when
* it is already known that there is nothing for them to translate.
*/
class AssignUniqueSymbols extends Phase {
val name = "assignUniqueSymbols"
def apply(state: CompilerState) = state.map { tree =>
val replace = new HashMap[TermSymbol, AnonSymbol]
def tr(n: Node): Node = {
val n3 = n match {
case Select(in, s) => Select(tr(in), s) :@ n.nodeType
case r @ Ref(a: AnonSymbol) =>
val s = replace.getOrElse(a, a)
if(s eq a) r else Ref(s)
case t: TableNode => t.copy(identity = new AnonTableIdentitySymbol)(t.driverTable)
case Pure(value, _) => Pure(tr(value))
case g: GroupBy =>
val d = g.copy(identity = new AnonTypeSymbol)
val a = new AnonSymbol
replace += g.fromGen -> a
g.copy(fromGen = a, tr(g.from), tr(g.by), identity = new AnonTypeSymbol)
case n: StructNode => n.mapChildren(tr)
case d: DefNode =>
replace ++= d.generators.iterator.map(_._1 -> new AnonSymbol)
d.mapSymbols(s => replace.getOrElse(s, s)).mapChildren(tr)
case n => n.mapChildren(tr)
type State = UsedFeatures
def apply(state: CompilerState) = {
var hasDistinct, hasTypeMapping, hasAggregate, hasNonPrimitiveOption = false
val s2 = state.map { tree =>
val replace = new HashMap[TermSymbol, AnonSymbol]
def checkFeatures(n: Node): Unit = n match {
case _: Distinct => hasDistinct = true
case _: TypeMapping => hasTypeMapping = true
case n: Apply =>
if(n.sym.isInstanceOf[AggregateFunctionSymbol]) hasAggregate = true
case (_: OptionFold | _: OptionApply | _: GetOrElse) => hasNonPrimitiveOption = true
case j: Join =>
if(j.jt == JoinType.LeftOption || j.jt == JoinType.RightOption || j.jt == JoinType.OuterOption) hasNonPrimitiveOption = true
case _ =>
}
// Remove all NominalTypes (which might have changed)
if(n3.hasType && hasNominalType(n3.nodeType)) n3.untyped else n3
def tr(n: Node): Node = {
val n3 = n match {
case Select(in, s) => Select(tr(in), s) :@ n.nodeType
case r @ Ref(a: AnonSymbol) =>
val s = replace.getOrElse(a, a)
if(s eq a) r else Ref(s)
case t: TableNode => t.copy(identity = new AnonTableIdentitySymbol)(t.driverTable)
case Pure(value, _) => Pure(tr(value))
case g: GroupBy =>
val d = g.copy(identity = new AnonTypeSymbol)
val a = new AnonSymbol
replace += g.fromGen -> a
g.copy(fromGen = a, tr(g.from), tr(g.by), identity = new AnonTypeSymbol)
case n: StructNode => n.mapChildren(tr)
case d: DefNode =>
checkFeatures(d)
replace ++= d.generators.iterator.map(_._1 -> new AnonSymbol)
d.mapSymbols(s => replace.getOrElse(s, s)).mapChildren(tr)
case n =>
checkFeatures(n)
n.mapChildren(tr)
}
// Remove all NominalTypes (which might have changed)
if(n3.hasType && hasNominalType(n3.nodeType)) n3.untyped else n3
}
tr(tree)
}
tr(tree)
val features = UsedFeatures(hasDistinct, hasTypeMapping, hasAggregate, hasNonPrimitiveOption)
logger.debug("Detected features: "+features)
s2 + (this -> features)
}
def hasNominalType(t: Type): Boolean = t match {
@@ -47,3 +73,5 @@ class AssignUniqueSymbols extends Phase {
case _ => t.children.exists(hasNominalType)
}
}
case class UsedFeatures(distinct: Boolean, typeMapping: Boolean, aggregate: Boolean, nonPrimitiveOption: Boolean)
@@ -10,51 +10,55 @@ import slick.util.{ConstArray, Ellipsis, ??}
class CreateAggregates extends Phase {
val name = "createAggregates"
def apply(state: CompilerState) = state.map(_.replace({
case n @ Apply(f: AggregateFunctionSymbol, ConstArray(from)) =>
logger.debug("Converting aggregation function application", n)
val CollectionType(_, elType @ Type.Structural(StructType(els))) = from.nodeType
val s = new AnonSymbol
val a = Aggregate(s, from, Apply(f, ConstArray(f match {
case Library.CountAll => LiteralNode(1)
case _ => Select(Ref(s) :@ elType, els.head._1) :@ els.head._2
}))(n.nodeType)).infer()
logger.debug("Converted aggregation function application", a)
inlineMap(a)
def apply(state: CompilerState) = {
if(state.get(Phase.assignUniqueSymbols).map(_.aggregate).getOrElse(true))
state.map(_.replace({
case n @ Apply(f: AggregateFunctionSymbol, ConstArray(from)) =>
logger.debug("Converting aggregation function application", n)
val CollectionType(_, elType @ Type.Structural(StructType(els))) = from.nodeType
val s = new AnonSymbol
val a = Aggregate(s, from, Apply(f, ConstArray(f match {
case Library.CountAll => LiteralNode(1)
case _ => Select(Ref(s) :@ elType, els.head._1) :@ els.head._2
}))(n.nodeType)).infer()
logger.debug("Converted aggregation function application", a)
inlineMap(a)
case n @ Bind(s1, from1, Pure(sel1, ts1)) if !from1.isInstanceOf[GroupBy] =>
val (sel2, temp) = liftAggregates(sel1, s1)
if(temp.isEmpty) n else {
logger.debug("Lifting aggregates into join in:", n)
logger.debug("New mapping with temporary refs:", sel2)
val sources = (from1 match {
case Pure(StructNode(ConstArray()), _) => Vector.empty[(TermSymbol, Node)]
case _ => Vector(s1 -> from1)
}) ++ temp.map { case (s, n) => (s, Pure(n)) }
val from2 = sources.init.foldRight(sources.last._2) {
case ((_, n), z) => Join(new AnonSymbol, new AnonSymbol, n, z, JoinType.Inner, LiteralNode(true))
}.infer()
logger.debug("New 'from' with joined aggregates:", from2)
val repl: Map[TermSymbol, List[TermSymbol]] = sources match {
case Vector((s, n)) => Map(s -> List(s1))
case _ =>
val len = sources.length
val it = Iterator.iterate(s1)(_ => ElementSymbol(2))
sources.zipWithIndex.map { case ((s, _), i) =>
val l = List.iterate(s1, i+1)(_ => ElementSymbol(2))
s -> (if(i == len-1) l else l :+ ElementSymbol(1))
}.toMap
}
logger.debug("Replacement paths: " + repl)
val scope = Type.Scope(s1 -> from2.nodeType.asCollectionType.elementType)
val replNodes = repl.mapValues(ss => FwdPath(ss).infer(scope))
logger.debug("Replacement path nodes: ", StructNode(ConstArray.from(replNodes)))
val sel3 = sel2.replace({ case n @ Ref(s) => replNodes.getOrElse(s, n) }, keepType = true)
val n2 = Bind(s1, from2, Pure(sel3, ts1)).infer()
logger.debug("Lifted aggregates into join in:", n2)
n2
}
}, keepType = true, bottomUp = true))
case n @ Bind(s1, from1, Pure(sel1, ts1)) if !from1.isInstanceOf[GroupBy] =>
val (sel2, temp) = liftAggregates(sel1, s1)
if(temp.isEmpty) n else {
logger.debug("Lifting aggregates into join in:", n)
logger.debug("New mapping with temporary refs:", sel2)
val sources = (from1 match {
case Pure(StructNode(ConstArray()), _) => Vector.empty[(TermSymbol, Node)]
case _ => Vector(s1 -> from1)
}) ++ temp.map { case (s, n) => (s, Pure(n)) }
val from2 = sources.init.foldRight(sources.last._2) {
case ((_, n), z) => Join(new AnonSymbol, new AnonSymbol, n, z, JoinType.Inner, LiteralNode(true))
}.infer()
logger.debug("New 'from' with joined aggregates:", from2)
val repl: Map[TermSymbol, List[TermSymbol]] = sources match {
case Vector((s, n)) => Map(s -> List(s1))
case _ =>
val len = sources.length
val it = Iterator.iterate(s1)(_ => ElementSymbol(2))
sources.zipWithIndex.map { case ((s, _), i) =>
val l = List.iterate(s1, i+1)(_ => ElementSymbol(2))
s -> (if(i == len-1) l else l :+ ElementSymbol(1))
}.toMap
}
logger.debug("Replacement paths: " + repl)
val scope = Type.Scope(s1 -> from2.nodeType.asCollectionType.elementType)
val replNodes = repl.mapValues(ss => FwdPath(ss).infer(scope))
logger.debug("Replacement path nodes: ", StructNode(ConstArray.from(replNodes)))
val sel3 = sel2.replace({ case n @ Ref(s) => replNodes.getOrElse(s, n) }, keepType = true)
val n2 = Bind(s1, from2, Pure(sel3, ts1)).infer()
logger.debug("Lifted aggregates into join in:", n2)
n2
}
}, keepType = true, bottomUp = true))
else state
}
/** Recursively inline mapping Bind calls under an Aggregate */
def inlineMap(a: Aggregate): Aggregate = a.from match {
@@ -73,7 +73,9 @@ class RemoveMappedTypes extends Phase {
type State = Type
def apply(state: CompilerState) =
state.withNode(removeTypeMapping(state.tree)) + (this -> state.tree.nodeType)
if(state.get(Phase.assignUniqueSymbols).map(_.typeMapping).getOrElse(true))
state.withNode(removeTypeMapping(state.tree)) + (this -> state.tree.nodeType)
else state + (this -> state.tree.nodeType)
/** Remove TypeMapping nodes and MappedTypes */
def removeTypeMapping(n: Node): Node = n match {
@@ -12,7 +12,10 @@ import scala.collection.mutable
class ExpandConditionals extends Phase {
val name = "expandConditionals"
def apply(state: CompilerState) = state.map(expand)
def apply(state: CompilerState) = {
if(state.get(Phase.expandSums).getOrElse(true)) state.map(expand)
else state
}
def expand(n: Node): Node = {
val invalid = mutable.HashSet.empty[TypeSymbol]
@@ -8,27 +8,41 @@ import TypeUtil._
import scala.collection.mutable
/** Expand sum types and their catamorphisms to equivalent product type operations. */
/** Expand sum types and their catamorphisms to equivalent product type operations.
* The phase state is a flag indicating whether there is anything left to clean up in
* `expandConditionals`. */
class ExpandSums extends Phase {
val name = "expandSums"
def apply(state: CompilerState) = state.map(n => tr(n, Set.empty))
type State = Boolean
def apply(state: CompilerState) = {
if(state.get(Phase.assignUniqueSymbols).map(_.nonPrimitiveOption).getOrElse(true)) {
val (n, multi) = tr(state.tree, Set.empty)
state.withNode(n) + (this -> multi)
} else state + (this -> false)
}
val Disc1 = LiteralNode(ScalaBaseType.optionDiscType.optionType, Option(1))
val DiscNone = LiteralNode(ScalaBaseType.optionDiscType.optionType, None)
/** Perform the sum expansion on a Node */
def tr(tree: Node, oldDiscCandidates: Set[(TypeSymbol, List[TermSymbol])]): Node = {
def tr(tree: Node, oldDiscCandidates: Set[(TypeSymbol, List[TermSymbol])]): (Node, Boolean) = {
val discCandidates = oldDiscCandidates ++ (tree match {
case Filter(_, _, p) => collectDiscriminatorCandidates(p)
case Bind(_, j: Join, _) => collectDiscriminatorCandidates(j.on)
case _ => Set.empty
})
val tree2 = tree.mapChildren(n => tr(n, discCandidates), keepType = true)
var multi = false
val tree2 = tree.mapChildren({ n =>
val (n2, flag) = tr(n, discCandidates)
multi |= flag
n2
}, keepType = true)
val tree3 = tree2 match {
// Expand multi-column null values in ELSE branches (used by Rep[Option].filter) with correct type
case IfThenElse(ConstArray(pred, then1 :@ tpe, LiteralNode(None) :@ OptionType(ScalaBaseType.nullType))) =>
multi = true
IfThenElse(ConstArray(pred, then1, buildMultiColumnNone(tpe))) :@ tpe
// Primitive OptionFold representing GetOrElse -> translate to GetOrElse
@@ -52,6 +66,7 @@ class ExpandSums extends Phase {
// Other OptionFold -> translate to discriminator check
case OptionFold(from, ifEmpty, map, gen) =>
multi = true
val left = from.select(ElementSymbol(1)).infer()
val pred = Library.==.typed[Boolean](left, LiteralNode(null))
val n2 = (ifEmpty, map) match {
@@ -70,7 +85,9 @@ class ExpandSums extends Phase {
case n @ OptionApply(_) :@ OptionType.Primitive(_) => n
// Other OptionApply -> translate to product form
case n @ OptionApply(ch) => ProductNode(ConstArray(Disc1, silentCast(toOptionColumns(ch.nodeType), ch))).infer()
case n @ OptionApply(ch) =>
multi = true
ProductNode(ConstArray(Disc1, silentCast(toOptionColumns(ch.nodeType), ch))).infer()
// Non-primitive GetOrElse
// (.get is only defined on primitive Options, but this can occur inside of HOFs like .map)
@@ -82,12 +99,13 @@ class ExpandSums extends Phase {
// Option-extended left outer, right outer or full outer join
case bind @ Bind(bsym, Join(_, _, _, _, jt, _), _) if jt == JoinType.LeftOption || jt == JoinType.RightOption || jt == JoinType.OuterOption =>
multi = true
translateJoin(bind, discCandidates)
case n => n
}
val tree4 = fuse(tree3)
tree4 :@ trType(tree4.nodeType)
(tree4 :@ trType(tree4.nodeType), multi)
}
/** Translate an Option-extended left outer, right outer or full outer join */
@@ -10,7 +10,7 @@ import slick.util.{Ellipsis, ConstArray}
class RewriteDistinct extends Phase {
val name = "rewriteDistinct"
def apply(state: CompilerState) = state.map(_.replace({
def apply(state: CompilerState) = if(state.get(Phase.assignUniqueSymbols).map(_.distinct).getOrElse(true)) state.map(_.replace({
case n @ Bind(s1, Distinct(s2, from1, on1), Pure(sel1, ts1)) =>
logger.debug("Rewriting Distinct:", Ellipsis(n, List(0, 0)))
@@ -54,5 +54,5 @@ class RewriteDistinct extends Phase {
ret
}
}, keepType = true, bottomUp = true))
}, keepType = true, bottomUp = true)) else state
}

0 comments on commit d286c17

Please sign in to comment.