Permalink
Browse files

Some small query compiler performance optimizations:

- Use IndexedSeq instead of Seq everywhere in the AST to avoid
  conversions.

- More efficient `Node.mapChildren` implementation.

- Add PathElement as a common abstraction for Select and Ref.

- More efficient `assignUniqueSymbols` implementation.

- Avoid unnecessary retyping in several phases.

This brings the performance back up to 3.0 levels.
  • Loading branch information...
szeiger committed Aug 14, 2015
1 parent a71aff4 commit 3bd24cc89849c32dabfd159dcc6ee3efc91f8097
@@ -73,7 +73,7 @@ final case class ResultSetMapping(generator: TermSymbol, from: Node, map: Node)
/** A switch for special-cased parameters that needs to be interpreted in order
* to find the correct query string for the query arguments. */
final case class ParameterSwitch(cases: Seq[((Any => Boolean), Node)], default: Node) extends SimplyTypedNode with ClientSideOp {
final case class ParameterSwitch(cases: IndexedSeq[((Any => Boolean), Node)], default: Node) extends SimplyTypedNode with ClientSideOp {
type Self = ParameterSwitch
def children = cases.map(_._2) :+ default
override def childNames = cases.map("[" + _._1 + "]") :+ "default"
@@ -5,11 +5,11 @@ import Util._
/** A SQL comprehension */
final case class Comprehension(sym: TermSymbol, from: Node, select: Node, where: Option[Node] = None,
groupBy: Option[Node] = None, orderBy: Seq[(Node, Ordering)] = Seq.empty,
groupBy: Option[Node] = None, orderBy: IndexedSeq[(Node, Ordering)] = Vector.empty,
having: Option[Node] = None,
fetch: Option[Node] = None, offset: Option[Node] = None) extends DefNode {
type Self = Comprehension
val children = Seq(from, select) ++ where ++ groupBy ++ orderBy.map(_._1) ++ having ++ fetch ++ offset
val children = Vector(from, select) ++ where ++ groupBy ++ orderBy.map(_._1) ++ having ++ fetch ++ offset
override def childNames =
Seq("from "+sym, "select") ++
where.map(_ => "where") ++
@@ -81,7 +81,7 @@ final case class Comprehension(sym: TermSymbol, from: Node, select: Node, where:
}
/** The row_number window function */
final case class RowNumber(by: Seq[(Node, Ordering)] = Seq.empty) extends SimplyTypedNode {
final case class RowNumber(by: IndexedSeq[(Node, Ordering)] = Vector.empty) extends SimplyTypedNode {
type Self = RowNumber
def buildType = ScalaBaseType.longType
lazy val children = by.map(_._1)
@@ -109,10 +109,10 @@ class FunctionSymbol(val name: String) extends TermSymbol {
}
/** Create a typed Apply of this Symbol */
def typed(tpe: Type, ch: Node*): Apply = Apply(this, ch)(tpe)
def typed(tpe: Type, ch: Node*): Apply = Apply(this, ch.toIndexedSeq)(tpe)
/** Create a typed Apply of this Symbol */
def typed[T : ScalaBaseType](ch: Node*): Apply = Apply(this, ch)(implicitly[ScalaBaseType[T]])
def typed[T : ScalaBaseType](ch: Node*): Apply = Apply(this, ch.toIndexedSeq)(implicitly[ScalaBaseType[T]])
override def toString = "Function "+name
}
@@ -16,7 +16,7 @@ trait Node extends Dumpable {
private var _type: Type = UnassignedType
/** All child nodes of this node. Must be implemented by subclasses. */
def children: Seq[Node]
def children: IndexedSeq[Node]
/** Names for the child nodes to show in AST dumps. Defaults to a numbered sequence starting at 0
* but can be overridden by subclasses to produce more suitable names. */
@@ -35,7 +35,7 @@ trait Node extends Dumpable {
* children. If all new children are identical to the old ones, this node is returned. If
* ``keepType`` is true, the type of this node is kept even when the children have changed. */
final def mapChildren(f: Node => Node, keepType: Boolean = false): Self = if(isInstanceOf[NullaryNode]) this else {
val n: Self = mapOrNone(children)(f).map(rebuild).getOrElse(this)
val n: Self = mapOrNone(children)(f).fold(this: Self)(rebuild)
if(_type == UnassignedType || !keepType) n else (n :@ _type).asInstanceOf[Self]
}
@@ -53,11 +53,11 @@ trait Node extends Dumpable {
/** Return this Node with no Type assigned (if it has not yet been observed) or an untyped copy. */
final def untyped: Self =
if(seenType || _type != UnassignedType) rebuild(children.toIndexedSeq) else this
if(seenType || _type != UnassignedType) rebuild(children) else this
/** Return this Node with a Type assigned (if no other type has been seen for it yet) or a typed copy. */
final def :@ (newType: Type): Self = {
val n: Self = if(seenType && newType != _type) rebuild(children.toIndexedSeq) else this
val n: Self = if(seenType && newType != _type) rebuild(children) else this
n._type = newType
n
}
@@ -106,7 +106,7 @@ trait SimplyTypedNode extends Node {
}
/** An expression that represents a conjunction of expressions. */
final case class ProductNode(children: Seq[Node]) extends SimplyTypedNode {
final case class ProductNode(children: IndexedSeq[Node]) extends SimplyTypedNode {
type Self = ProductNode
override def getDumpInfo = super.getDumpInfo.copy(name = "ProductNode", mainInfo = "")
protected[this] def rebuild(ch: IndexedSeq[Node]): Self = copy(ch)
@@ -115,11 +115,11 @@ final case class ProductNode(children: Seq[Node]) extends SimplyTypedNode {
val t = ch.nodeType
if(t == UnassignedType) throw new SlickException(s"ProductNode child $ch has UnassignedType")
t
}(collection.breakOut))
})
def flatten: ProductNode = {
def f(n: Node): IndexedSeq[Node] = n match {
case ProductNode(ns) => ns.flatMap(f).toIndexedSeq
case StructNode(els) => els.flatMap(el => f(el._2)).toIndexedSeq
case ProductNode(ns) => ns.flatMap(f)
case StructNode(els) => els.flatMap(el => f(el._2))
case n => IndexedSeq(n)
}
ProductNode(f(this))
@@ -175,20 +175,20 @@ object LiteralNode {
trait BinaryNode extends Node {
def left: Node
def right: Node
lazy val children = Seq(left, right)
lazy val children = Vector(left, right)
protected[this] final def rebuild(ch: IndexedSeq[Node]): Self = rebuild(ch(0), ch(1))
protected[this] def rebuild(left: Node, right: Node): Self
}
trait UnaryNode extends Node {
def child: Node
lazy val children = Seq(child)
lazy val children = Vector(child)
protected[this] final def rebuild(ch: IndexedSeq[Node]): Self = rebuild(ch(0))
protected[this] def rebuild(child: Node): Self
}
trait NullaryNode extends Node {
val children = Nil
val children = Vector.empty
protected[this] final def rebuild(ch: IndexedSeq[Node]): Self = rebuild
protected[this] def rebuild: Self
}
@@ -247,7 +247,7 @@ abstract class FilteredQuery extends Node {
val genScope = scope + (generator -> from2.nodeType.asCollectionType.elementType)
val ch2: IndexedSeq[Node] = children.map { ch =>
if(ch eq from) from2 else ch.infer(genScope, typeChildren)
}(collection.breakOut)
}
(withChildren(ch2) :@ (if(!hasType) ch2.head.nodeType else nodeType)).asInstanceOf[Self]
}
}
@@ -270,7 +270,7 @@ object Filter {
}
/** A .sortBy call of type (CollectionType(c, t), _) => CollectionType(c, t). */
final case class SortBy(generator: TermSymbol, from: Node, by: Seq[(Node, Ordering)]) extends FilteredQuery with DefNode {
final case class SortBy(generator: TermSymbol, from: Node, by: IndexedSeq[(Node, Ordering)]) extends FilteredQuery with DefNode {
type Self = SortBy
lazy val children = from +: by.map(_._1)
protected[this] def rebuild(ch: IndexedSeq[Node]) =
@@ -443,8 +443,13 @@ final case class TableExpansion(generator: TermSymbol, table: Node, columns: Nod
}
}
trait PathElement extends Node {
def sym: TermSymbol
}
/** An expression that selects a field in another expression. */
final case class Select(in: Node, field: TermSymbol) extends UnaryNode with SimplyTypedNode {
final case class Select(in: Node, field: TermSymbol) extends PathElement with UnaryNode with SimplyTypedNode {
def sym = field
type Self = Select
def child = in
override def childNames = Seq("in")
@@ -457,14 +462,14 @@ final case class Select(in: Node, field: TermSymbol) extends UnaryNode with Simp
}
/** A function call expression. */
final case class Apply(sym: TermSymbol, children: Seq[Node])(val buildType: Type) extends SimplyTypedNode {
final case class Apply(sym: TermSymbol, children: IndexedSeq[Node])(val buildType: Type) extends SimplyTypedNode {
type Self = Apply
protected[this] def rebuild(ch: IndexedSeq[slick.ast.Node]) = copy(children = ch)(buildType)
override def getDumpInfo = super.getDumpInfo.copy(mainInfo = sym.toString)
}
/** A reference to a Symbol */
final case class Ref(sym: TermSymbol) extends NullaryNode {
final case class Ref(sym: TermSymbol) extends PathElement with NullaryNode {
type Self = Ref
def withInferredType(scope: Type.Scope, typeChildren: Boolean): Self =
if(hasType) this else {
@@ -479,13 +484,13 @@ final case class Ref(sym: TermSymbol) extends NullaryNode {
/** A constructor/extractor for nested Selects starting at a Ref so that, for example,
* `c :: b :: a :: Nil` corresponds to path `a.b.c`. */
object Path {
def apply(l: List[TermSymbol]): Node = l match {
def apply(l: List[TermSymbol]): PathElement = l match {
case s :: Nil => Ref(s)
case s :: l => Select(apply(l), s)
}
def unapply(n: Node): Option[List[TermSymbol]] = n match {
def unapply(n: PathElement): Option[List[TermSymbol]] = n match {
case Ref(sym) => Some(List(sym))
case Select(in, s) => unapply(in).map(l => s :: l)
case Select(in: PathElement, s) => unapply(in).map(l => s :: l)
case _ => None
}
def toString(path: Seq[TermSymbol]): String = path.reverseIterator.mkString("Path ", ".", "")
@@ -499,7 +504,7 @@ object Path {
* `a :: b :: c :: Nil` corresponds to path `a.b.c`. */
object FwdPath {
def apply(ch: List[TermSymbol]) = Path(ch.reverse)
def unapply(n: Node): Option[List[TermSymbol]] = Path.unapply(n).map(_.reverse)
def unapply(n: PathElement): Option[List[TermSymbol]] = Path.unapply(n).map(_.reverse)
def toString(path: Seq[TermSymbol]): String = path.mkString("Path ", ".", "")
}
@@ -13,7 +13,7 @@ import slick.util.{DumpInfo, Dumpable, TupleSupport}
/** Super-trait for all types */
trait Type extends Dumpable {
/** All children of this Type. */
def children: Seq[Type]
def children: IndexedSeq[Type]
/** Apply a transformation to all type children and reconstruct this
* type with the new children, or return the original object if no
* child is changed. */
@@ -42,7 +42,7 @@ object Type {
/** An atomic type (i.e. a type which does not contain other types) */
trait AtomicType extends Type {
final def mapChildren(f: Type => Type): this.type = this
def children: Seq[Type] = Seq.empty
def children = Vector.empty
}
final case class StructType(elements: IndexedSeq[(TermSymbol, Type)]) extends Type {
@@ -65,7 +65,7 @@ final case class StructType(elements: IndexedSeq[(TermSymbol, Type)]) extends Ty
trait OptionType extends Type {
override def toString = "Option[" + elementType + "]"
def elementType: Type
def children: Seq[Type] = Seq(elementType)
def children: IndexedSeq[Type] = Vector(elementType)
def classTag = OptionType.classTag
override def hashCode = elementType.hashCode() + 100
override def equals(o: Any) = o match {
@@ -118,7 +118,7 @@ final case class ProductType(elements: IndexedSeq[Type]) extends Type {
case ElementSymbol(i) if i <= elements.length => elements(i-1)
case _ => super.select(sym)
}
def children: Seq[Type] = elements
def children: IndexedSeq[Type] = elements
def numberedElements: Iterator[(ElementSymbol, Type)] =
elements.iterator.zipWithIndex.map { case (t, i) => (new ElementSymbol(i+1), t) }
def classTag = TupleSupport.classTagForArity(elements.size)
@@ -131,7 +131,7 @@ final case class CollectionType(cons: CollectionTypeConstructor, elementType: Ty
if(e2 eq elementType) this
else CollectionType(cons, e2)
}
def children: Seq[Type] = Seq(elementType)
def children: IndexedSeq[Type] = Vector(elementType)
def classTag = cons.classTag
}
@@ -200,7 +200,7 @@ final class MappedScalaType(val baseType: Type, val mapper: MappedScalaType.Mapp
if(e2 eq baseType) this
else new MappedScalaType(e2, mapper, classTag)
}
def children: Seq[Type] = Seq(baseType)
def children: IndexedSeq[Type] = Vector(baseType)
override def select(sym: TermSymbol) = baseType.select(sym)
override def hashCode = baseType.hashCode() + mapper.hashCode() + classTag.hashCode()
override def equals(o: Any) = o match {
@@ -235,7 +235,7 @@ final case class NominalType(sym: TypeSymbol, structuralView: Type) extends Type
if(struct2 eq structuralView) this
else new NominalType(sym, struct2)
}
def children: Seq[Type] = Seq(structuralView)
def children: IndexedSeq[Type] = Vector(structuralView)
def sourceNominalType: NominalType = structuralView match {
case n: NominalType => n.sourceNominalType
case _ => this
@@ -70,15 +70,13 @@ final class NodeOps(val tree: Node) extends AnyVal {
* retyped afterwards to get the correct new TypeSymbols in. */
def replaceInvalidate(f: PartialFunction[(Node, Set[TypeSymbol], Node), (Node, Set[TypeSymbol])]): Node = {
replaceFold(Set.empty[TypeSymbol])(f.orElse {
case ((n: Ref), invalid, _) if containsTS(n.nodeType, invalid) => (n.untyped, invalid)
case ((n: Select), invalid, _) if containsTS(n.nodeType, invalid) => (n.untyped, invalid)
case ((n: PathElement), invalid, _) if containsTS(n.nodeType, invalid) => (n.untyped, invalid)
})._1
}
def untypeReferences(invalid: Set[TypeSymbol]): Node = {
if(invalid.isEmpty) tree else replace({
case n: Ref if containsTS(n.nodeType, invalid) => n.untyped
case n: Select if containsTS(n.nodeType, invalid) => n.untyped
case n: PathElement if containsTS(n.nodeType, invalid) => n.untyped
}, bottomUp = true)
}
@@ -16,23 +16,13 @@ class AssignUniqueSymbols extends Phase {
def apply(state: CompilerState) = state.map { tree =>
val seen = new HashSet[AnonSymbol]
val seenType = new HashSet[TypeSymbol]
def tr(n: Node, replace: Map[AnonSymbol, AnonSymbol]): Node = {
val n2 = n match { // Give TableNode and Pure nodes a unique TypeSymbol
case t: TableNode =>
t.copy(identity = new AnonTableIdentitySymbol)
case p @ Pure(value, ts) =>
if(seenType contains ts) Pure(value)
else {
seenType += ts
p
}
case t: TableNode => t.copy(identity = new AnonTableIdentitySymbol)
case Pure(value, _) => Pure(value)
case n => n
}
val n3 = // Remove all NominalTypes (which might have changed)
if(n2.hasType && !(n2.nodeType.collect { case _: NominalType => () }).isEmpty) n2.untyped
else n2
n3 match {
val n3 = n2 match {
case r @ Ref(a: AnonSymbol) => replace.get(a) match {
case Some(s) => if(s eq a) r else Ref(s)
case None => r
@@ -55,7 +45,14 @@ class AssignUniqueSymbols extends Phase {
case n: Select => n.mapChildren(tr(_, replace)) :@ n.nodeType
case n => n.mapChildren(tr(_, replace))
}
// Remove all NominalTypes (which might have changed)
if(n3.hasType && hasNominalType(n3.nodeType)) n3.untyped else n3
}
tr(tree, Map())
}
def hasNominalType(t: Type): Boolean = t match {
case _: NominalType => true
case _ => t.children.exists(hasNominalType)
}
}
@@ -10,14 +10,12 @@ import slick.util.{Ellipsis, ??}
class CreateAggregates extends Phase {
val name = "createAggregates"
def apply(state: CompilerState) = state.map(tr)
def tr(n: Node): Node = n.mapChildren(tr, keepType = true) match {
def apply(state: CompilerState) = state.map(_.replace({
case n @ Apply(f: AggregateFunctionSymbol, Seq(from)) :@ tpe =>
logger.debug("Converting aggregation function application", n)
val CollectionType(_, elType @ Type.Structural(StructType(els))) = from.nodeType
val s = new AnonSymbol
val a = Aggregate(s, from, Apply(f, Seq(f match {
val a = Aggregate(s, from, Apply(f, Vector(f match {
case Library.CountAll => LiteralNode(1)
case _ => Select(Ref(s) :@ elType, els.head._1) :@ els.head._2
}))(tpe)).infer()
@@ -56,9 +54,7 @@ class CreateAggregates extends Phase {
logger.debug("Lifted aggregates into join in:", n2)
n2
}
case n => n
}
}, keepType = true, bottomUp = true))
/** Recursively inline mapping Bind calls under an Aggregate */
def inlineMap(a: Aggregate): Aggregate = a.from match {
@@ -80,8 +76,7 @@ class CreateAggregates extends Phase {
def liftAggregates(n: Node, outer: TermSymbol): (Node, Map[TermSymbol, Aggregate]) = n match {
case a @ Aggregate(s1, f1, sel1) =>
if(a.findNode {
case Ref(s) => s == outer
case Select(_, s) => s == outer
case n: PathElement => n.sym == outer
case _ => false
}.isDefined) (a, Map.empty)
else {
@@ -32,14 +32,14 @@ class EmulateOuterJoins(val useLeftJoin: Boolean, val useRightJoin: Boolean) ext
Filter(lgen2, left,
Library.Not.typed(on.nodeType, Library.Exists.typed(on.nodeType, Filter(rgen2, right, on2)))
),
Pure(ProductNode(Seq(Ref(bgen), nullStructFor(right.nodeType.structural.asCollectionType.elementType))))
Pure(ProductNode(Vector(Ref(bgen), nullStructFor(right.nodeType.structural.asCollectionType.elementType))))
), true).infer())
case Join(leftGen, rightGen, left, right, JoinType.Right, on) if !useRightJoin =>
// as rightJoin bs on e => bs leftJoin as on { (b, a) => e(a, b) } map { case (b, a) => (a, b) }
val bgen = new AnonSymbol
convert(Bind(bgen,
Join(rightGen, leftGen, right, left, JoinType.Left, on),
Pure(ProductNode(Seq(Select(Ref(bgen), ElementSymbol(2)), Select(Ref(bgen), ElementSymbol(1)))))
Pure(ProductNode(Vector(Select(Ref(bgen), ElementSymbol(2)), Select(Ref(bgen), ElementSymbol(1)))))
).infer())
case Join(leftGen, rightGen, left, right, JoinType.Outer, on) =>
// as fullJoin bs on e => (as leftJoin bs on e) unionAll bs.filter(b => !exists(as.filter(a => e(a, b)))).map(b => (nulls, b))
@@ -53,7 +53,7 @@ class EmulateOuterJoins(val useLeftJoin: Boolean, val useRightJoin: Boolean) ext
Filter(rgen2, right,
Library.Not.typed(on.nodeType, Library.Exists.typed(on.nodeType, Filter(lgen2, left, on2)))
),
Pure(ProductNode(Seq(nullStructFor(left.nodeType.structural.asCollectionType.elementType), Ref(bgen))))
Pure(ProductNode(Vector(nullStructFor(left.nodeType.structural.asCollectionType.elementType), Ref(bgen))))
), true).infer())
case n => n.mapChildren(convert, true)
}
@@ -69,7 +69,7 @@ class ExpandConditionals extends Phase {
val n2 = tr(n)
logger.debug("Invalidated TypeSymbols: "+invalid.mkString(", "))
n2.replace({
case n @ (_: Ref | _: Select) :@ tpe if invalid.intersect(tpe.collect { case NominalType(ts, _) => ts }.toSet).nonEmpty =>
case (n: PathElement) :@ tpe if invalid.intersect(tpe.collect { case NominalType(ts, _) => ts }.toSet).nonEmpty =>
n.untyped
}, bottomUp = true).infer()
}
Oops, something went wrong.

0 comments on commit 3bd24cc

Please sign in to comment.