Skip to content

Commit

Permalink
More micro-optimizations and streamlining of ConstArray use
Browse files Browse the repository at this point in the history
  • Loading branch information
szeiger committed Aug 26, 2015
1 parent 2adb7c3 commit c9bdcb1
Show file tree
Hide file tree
Showing 14 changed files with 96 additions and 61 deletions.
26 changes: 21 additions & 5 deletions slick/src/main/scala/slick/ast/Node.scala
Expand Up @@ -53,6 +53,12 @@ trait Node extends Dumpable {
if(!keepType || (_type eq UnassignedType)) n else (n :@ _type).asInstanceOf[Self]
}

/** Apply a side-effecting function to all direct children from left to right. Note that
* {{{ n.childrenForeach(f) }}} is equivalent to {{{ n.children.foreach(f) }}} but can be
* implemented more efficiently in `Node` subclasses. */
def childrenForeach[R](f: Node => R): Unit =
children.foreach(f)

/** The current type of this node. */
def nodeType: Type = {
seenType = true
Expand Down Expand Up @@ -202,6 +208,10 @@ trait BinaryNode extends Node {
if(!keepType || (_type eq UnassignedType)) n else (n :@ _type).asInstanceOf[Self]
}
override final protected[this] def buildCopy: Self = rebuild(left, right)
override final def childrenForeach[R](f: Node => R): Unit = {
f(left)
f(right)
}
}

trait UnaryNode extends Node {
Expand All @@ -217,6 +227,7 @@ trait UnaryNode extends Node {
if(!keepType || (_type eq UnassignedType)) n else (n :@ _type).asInstanceOf[Self]
}
override final protected[this] def buildCopy: Self = rebuild(child)
override final def childrenForeach[R](f: Node => R): Unit = f(child)
}

trait NullaryNode extends Node {
Expand All @@ -225,6 +236,7 @@ trait NullaryNode extends Node {
protected[this] def rebuild: Self
override final def mapChildren(f: Node => Node, keepType: Boolean = false): Self = this
override final protected[this] def buildCopy: Self = rebuild
override final def childrenForeach[R](f: Node => R): Unit = ()
}

/** An expression that represents a plain value lifted into a Query. */
Expand Down Expand Up @@ -268,6 +280,7 @@ object Subquery {

/** Common superclass for expressions of type (CollectionType(c, t), _) => CollectionType(c, t). */
abstract class FilteredQuery extends Node {
type Self >: this.type <: FilteredQuery
protected[this] def generator: TermSymbol
def from: Node
def generators = ConstArray((generator, from))
Expand All @@ -279,10 +292,10 @@ abstract class FilteredQuery extends Node {
def withInferredType(scope: Type.Scope, typeChildren: Boolean): Self = {
val from2 = from.infer(scope, typeChildren)
val genScope = scope + (generator -> from2.nodeType.asCollectionType.elementType)
val ch2: ConstArray[Node] = children.map[Node] { ch =>
val this2 = mapChildren { ch =>
if(ch eq from) from2 else ch.infer(genScope, typeChildren)
}
(withChildren(ch2) :@ (if(!hasType) ch2.head.nodeType else nodeType)).asInstanceOf[Self]
(this2 :@ (if(!hasType) this2.from.nodeType else nodeType)).asInstanceOf[Self]
}
}

Expand Down Expand Up @@ -348,7 +361,8 @@ final case class GroupBy(fromGen: TermSymbol, from: Node, by: Node, identity: Ty
val from2 = from.infer(scope, typeChildren)
val from2Type = from2.nodeType.asCollectionType
val by2 = by.infer(scope + (fromGen -> from2Type.elementType), typeChildren)
withChildren(ConstArray[Node](from2, by2)) :@ (
val this2 = if((from2 eq from) && (by2 eq by)) this else copy(from = from2, by = by2)
this2 :@ (
if(!hasType)
CollectionType(from2Type.cons, ProductType(ConstArray(NominalType(identity, by2.nodeType), CollectionType(TypedCollectionTypeConstructor.seq, from2Type.elementType))))
else nodeType)
Expand Down Expand Up @@ -457,7 +471,8 @@ final case class Aggregate(sym: TermSymbol, from: Node, select: Node) extends Bi
def withInferredType(scope: Type.Scope, typeChildren: Boolean): Self = {
val from2 :@ CollectionType(_, el) = from.infer(scope, typeChildren)
val select2 = select.infer(scope + (sym -> el), typeChildren)
withChildren(ConstArray[Node](from2, select2)) :@ (if(!hasType) select2.nodeType else nodeType)
val this2 = if((from2 eq from) && (select2 eq select)) this else copy(from = from2, select = select2)
this2 :@ (if(!hasType) select2.nodeType else nodeType)
}
}

Expand All @@ -474,7 +489,8 @@ final case class TableExpansion(generator: TermSymbol, table: Node, columns: Nod
def withInferredType(scope: Type.Scope, typeChildren: Boolean): Self = {
val table2 = table.infer(scope, typeChildren)
val columns2 = columns.infer(scope + (generator -> table2.nodeType.asCollectionType.elementType), typeChildren)
withChildren(ConstArray[Node](table2, columns2)) :@ (if(!hasType) table2.nodeType else nodeType)
val this2 = if((table2 eq table) && (columns2 eq columns)) this else copy(table = table2, columns = columns2)
this2 :@ (if(!hasType) table2.nodeType else nodeType)
}
}

Expand Down
10 changes: 7 additions & 3 deletions slick/src/main/scala/slick/ast/Symbol.scala
Expand Up @@ -54,10 +54,14 @@ trait DefNode extends Node {
protected[this] def rebuildWithSymbols(gen: ConstArray[TermSymbol]): Node

final def mapScopedChildren(f: (Option[TermSymbol], Node) => Node): Self with DefNode = {
val all = (generators.iterator.map{ case (sym, n) => (Some(sym), n) } ++
children.iterator.drop(generators.length).map{ n => (None, n) }).toIndexedSeq
val gens = generators
val ch = children
val all = ch.zipWithIndex.map[(Option[TermSymbol], Node)] { case (ch, idx) =>
val o = if(idx < gens.length) Some(gens(idx)._1) else None
(o, ch)
}
val mapped = all.map(f.tupled)
if((all, mapped).zipped.map((a, m) => a._2 eq m).contains(false)) rebuild(ConstArray.from(mapped)).asInstanceOf[Self with DefNode]
if(ch.zip(mapped).force.exists { case (n1, n2) => n1 ne n2 }) rebuild(mapped).asInstanceOf[Self with DefNode]
else this
}
final def mapSymbols(f: TermSymbol => TermSymbol): Node = {
Expand Down
40 changes: 26 additions & 14 deletions slick/src/main/scala/slick/ast/Type.scala
Expand Up @@ -18,6 +18,9 @@ trait Type extends Dumpable {
* type with the new children, or return the original object if no
* child is changed. */
def mapChildren(f: Type => Type): Type
/** Apply a side-effecting function to all children. */
def childrenForeach[R](f: Type => R): Unit =
children.foreach(f)
def select(sym: TermSymbol): Type = throw new SlickException(s"No type for symbol $sym found in $this")
/** The structural view of this type */
def structural: Type = this
Expand All @@ -26,7 +29,7 @@ trait Type extends Dumpable {
/** A ClassTag for the erased type of this type's Scala values */
def classTag: ClassTag[_]
def getDumpInfo = DumpInfo(DumpInfo.simpleNameFor(getClass), toString, "",
children.toSeq.zipWithIndex.map { case (ch, i) => (i.toString, ch) })
children.zipWithIndex.map { case (ch, i) => (i.toString, ch) }.toSeq)
}

object Type {
Expand All @@ -43,12 +46,13 @@ object Type {
trait AtomicType extends Type {
final def mapChildren(f: Type => Type): this.type = this
def children = ConstArray.empty
override final def childrenForeach[R](f: Type => R): Unit = ()
}

final case class StructType(elements: ConstArray[(TermSymbol, Type)]) extends Type {
override def toString = "{" + elements.iterator.map{ case (s, t) => s + ": " + t }.mkString(", ") + "}"
lazy val symbolToIndex: Map[TermSymbol, Int] =
elements.iterator.zipWithIndex.map { case ((sym, _), idx) => (sym, idx) }.toMap
elements.zipWithIndex.map { case ((sym, _), idx) => (sym, idx) }.toMap
def children: ConstArray[Type] = elements.map(_._2)
def mapChildren(f: Type => Type): StructType = {
val ch = elements.map(_._2)
Expand All @@ -57,9 +61,12 @@ final case class StructType(elements: ConstArray[(TermSymbol, Type)]) extends Ty
}
override def select(sym: TermSymbol) = sym match {
case ElementSymbol(idx) => elements(idx-1)._2
case _ => elements.find(x => x._1 == sym).map(_._2).getOrElse(super.select(sym))
case _ =>
val i = elements.indexWhere(_._1 == sym)
if(i >= 0) elements(i)._2 else super.select(sym)
}
def classTag = TupleSupport.classTagForArity(elements.length)
override final def childrenForeach[R](f: Type => R): Unit = elements.foreach(t => f(t._2))
}

trait OptionType extends Type {
Expand All @@ -72,6 +79,7 @@ trait OptionType extends Type {
case OptionType(elem) if elementType == elem => true
case _ => false
}
override final def childrenForeach[R](f: Type => R): Unit = f(elementType)
}

object OptionType {
Expand All @@ -93,15 +101,15 @@ object OptionType {
/** An extractor for a non-nested Option type of a single column */
object Primitive {
def unapply(tpe: Type): Option[Type] = tpe.structural match {
case o: OptionType if o.elementType.structural.children.isEmpty => Some(o.elementType)
case o: OptionType if o.elementType.structural.isInstanceOf[AtomicType] => Some(o.elementType)
case _ => None
}
}

/** An extractor for a nested or multi-column Option type */
object NonPrimitive {
def unapply(tpe: Type): Option[Type] = tpe.structural match {
case o: OptionType if o.elementType.structural.children.nonEmpty => Some(o.elementType)
case o: OptionType if !o.elementType.structural.isInstanceOf[AtomicType] => Some(o.elementType)
case _ => None
}
}
Expand All @@ -118,8 +126,6 @@ final case class ProductType(elements: ConstArray[Type]) extends Type {
case _ => super.select(sym)
}
def children: ConstArray[Type] = elements
def numberedElements: Iterator[(ElementSymbol, Type)] =
elements.iterator.zipWithIndex.map { case (t, i) => (new ElementSymbol(i+1), t) }
def classTag = TupleSupport.classTagForArity(elements.length)
}

Expand All @@ -130,6 +136,7 @@ final case class CollectionType(cons: CollectionTypeConstructor, elementType: Ty
if(e2 eq elementType) this
else CollectionType(cons, e2)
}
override final def childrenForeach[R](f: Type => R): Unit = f(elementType)
def children: ConstArray[Type] = ConstArray(elementType)
def classTag = cons.classTag
}
Expand Down Expand Up @@ -199,6 +206,7 @@ final class MappedScalaType(val baseType: Type, val mapper: MappedScalaType.Mapp
if(e2 eq baseType) this
else new MappedScalaType(e2, mapper, classTag)
}
override final def childrenForeach[R](f: Type => R): Unit = f(baseType)
def children: ConstArray[Type] = ConstArray(baseType)
override def select(sym: TermSymbol) = baseType.select(sym)
override def hashCode = baseType.hashCode() + mapper.hashCode() + classTag.hashCode()
Expand Down Expand Up @@ -234,6 +242,7 @@ final case class NominalType(sym: TypeSymbol, structuralView: Type) extends Type
if(struct2 eq structuralView) this
else new NominalType(sym, struct2)
}
override final def childrenForeach[R](f: Type => R): Unit = f(structuralView)
def children: ConstArray[Type] = ConstArray(structuralView)
def sourceNominalType: NominalType = structuralView match {
case n: NominalType => n.sourceNominalType
Expand Down Expand Up @@ -284,19 +293,22 @@ class TypeUtil(val tpe: Type) extends AnyVal {
def replace(f: PartialFunction[Type, Type]): Type =
f.applyOrElse(tpe, { case t: Type => t.mapChildren(_.replace(f)) }: PartialFunction[Type, Type])

def collect[T](pf: PartialFunction[Type, T]): Iterable[T] = {
val b = new ArrayBuffer[T]
def g(n: Type) {
pf.andThen[Unit]{ case t => b += t }.orElse[Type, Unit]{ case _ => () }.apply(n)
n.children.foreach(g)
def collect[T](pf: PartialFunction[Type, T]): ConstArray[T] = {
val retNull: (Type => T) = (_ => null.asInstanceOf[T])
val b = ConstArray.newBuilder[T]()
def f(n: Type): Unit = {
val r = pf.applyOrElse(n, retNull)
if(r.asInstanceOf[AnyRef] ne null) b += r
n.childrenForeach(f)
}
g(tpe)
b
f(tpe)
b.result
}

def containsSymbol(tss: scala.collection.Set[TypeSymbol]): Boolean = {
if(tss.isEmpty) false else tpe match {
case NominalType(ts, exp) => tss.contains(ts) || exp.containsSymbol(tss)
case t: AtomicType => false
case t => t.children.exists(_.containsSymbol(tss))
}
}
Expand Down
48 changes: 25 additions & 23 deletions slick/src/main/scala/slick/ast/Util.scala
@@ -1,7 +1,7 @@
package slick.ast

import slick.ast.TypeUtil.:@
import slick.util.{ConstArray, ConstArrayBuilder}
import slick.util.ConstArray

import scala.collection
import scala.collection.mutable
Expand Down Expand Up @@ -29,30 +29,41 @@ final class NodeOps(val tree: Node) extends AnyVal {
import Util._
import NodeOps._

@inline def collect[T](pf: PartialFunction[Node, T], stopOnMatch: Boolean = false): ConstArray[T] = {
val b = new ConstArrayBuilder[T]
def f(n: Node): Unit = pf.andThen[Unit] { case t =>
b += t
if(!stopOnMatch) n.children.foreach(f)
}.orElse[Node, Unit]{ case _ =>
n.children.foreach(f)
}.apply(n)
def collect[T](pf: PartialFunction[Node, T], stopOnMatch: Boolean = false): ConstArray[T] = {
val retNull: (Node => T) = (_ => null.asInstanceOf[T])
val b = ConstArray.newBuilder[T]()
def f(n: Node): Unit = {
val r = pf.applyOrElse(n, retNull)
if(r.asInstanceOf[AnyRef] ne null) {
b += r
if(!stopOnMatch) n.childrenForeach(f)
}
else n.childrenForeach(f)
}
f(tree)
b.result
}

def collectAll[T](pf: PartialFunction[Node, ConstArray[T]]): ConstArray[T] = collect[ConstArray[T]](pf).flatten

def replace(f: PartialFunction[Node, Node], keepType: Boolean = false, bottomUp: Boolean = false): Node = {
def g(n: Node): Node = n.mapChildren(_.replace(f, keepType, bottomUp), keepType)
if(bottomUp) f.applyOrElse(g(tree), identity[Node]) else f.applyOrElse(tree, g)
if(bottomUp) {
def r(n: Node): Node = f.applyOrElse(g(n), identity[Node])
def g(n: Node): Node = n.mapChildren(r, keepType)
r(tree)
} else {
def r(n: Node): Node = f.applyOrElse(n, g)
def g(n: Node): Node = n.mapChildren(r, keepType)
r(tree)
}
}

/** Replace nodes in a bottom-up traversal while invalidating TypeSymbols. Any later references
* to the invalidated TypeSymbols have their types unassigned, so that the whole tree can be
* retyped afterwards to get the correct new TypeSymbols in. The PartialFunction may return
* `null`, which is considered the same as not matching. */
def replaceInvalidate(f: PartialFunction[Node, (Node, TypeSymbol)]): Node = {
import TypeUtil.typeToTypeUtil
val invalid = mutable.HashSet.empty[TypeSymbol]
val default = (_: Node) => null
def tr(n: Node): Node = {
Expand All @@ -62,16 +73,17 @@ final class NodeOps(val tree: Node) extends AnyVal {
invalid += res._2
res._1
} else n2 match {
case n2: PathElement if containsTS(n2.nodeType, invalid) => n2.untyped
case n2: PathElement if n2.nodeType.containsSymbol(invalid) => n2.untyped
case _ => n2
}
}
tr(tree)
}

def untypeReferences(invalid: Set[TypeSymbol]): Node = {
import TypeUtil.typeToTypeUtil
if(invalid.isEmpty) tree else replace({
case n: PathElement if containsTS(n.nodeType, invalid) => n.untyped
case n: PathElement if n.nodeType.containsSymbol(invalid) => n.untyped
}, bottomUp = true)
}

Expand All @@ -90,13 +102,3 @@ final class NodeOps(val tree: Node) extends AnyVal {
case (s, n) => Select(n, s)
}
}

private object NodeOps {
private def containsTS(t: Type, invalid: collection.Set[TypeSymbol]): Boolean = {
if(invalid.isEmpty) false else t match {
case NominalType(ts, exp) => invalid.contains(ts) || containsTS(exp, invalid)
case t: AtomicType => false
case t => t.children.exists(ch => containsTS(ch, invalid))
}
}
}
Expand Up @@ -53,6 +53,7 @@ class AssignUniqueSymbols extends Phase {

def hasNominalType(t: Type): Boolean = t match {
case _: NominalType => true
case _: AtomicType => false
case _ => t.children.exists(hasNominalType)
}
}
Expand Up @@ -48,7 +48,7 @@ class CreateResultSetMapping extends Phase {
ProductNode(ch.map { case (_, t) => f(t) })
case t: MappedScalaType =>
TypeMapping(f(t.baseType), t.mapper, t.classTag)
case o @ OptionType(Type.Structural(el)) if el.children.nonEmpty =>
case o @ OptionType(Type.Structural(el)) if !el.isInstanceOf[AtomicType] =>
val discriminator = Select(ref, syms(curIdx)).infer()
curIdx += 1
val data = f(o.elementType)
Expand Down
Expand Up @@ -16,7 +16,7 @@ class ExpandConditionals extends Phase {

def expand(n: Node): Node = {
val invalid = mutable.HashSet.empty[TypeSymbol]
def invalidate(n: Node): Unit = invalid ++= n.nodeType.collect { case NominalType(ts, _) => ts }
def invalidate(n: Node): Unit = invalid ++= n.nodeType.collect { case NominalType(ts, _) => ts }.toSeq

def tr(n: Node): Node = n.mapChildren(tr, keepType = true) match {
// Expand multi-column SilentCasts
Expand Down
10 changes: 5 additions & 5 deletions slick/src/main/scala/slick/compiler/ExpandSums.scala
Expand Up @@ -94,8 +94,8 @@ class ExpandSums extends Phase {
def translateJoin(bind: Bind, discCandidates: Set[(TypeSymbol, List[TermSymbol])]): Bind = {
logger.debug("translateJoin", bind)
val Bind(bsym, (join @ Join(lsym, rsym, left :@ CollectionType(_, leftElemType), right :@ CollectionType(_, rightElemType), jt, on)) :@ CollectionType(cons, elemType), pure) = bind
val lComplex = leftElemType.structural.children.nonEmpty
val rComplex = rightElemType.structural.children.nonEmpty
val lComplex = !leftElemType.structural.isInstanceOf[AtomicType]
val rComplex = !rightElemType.structural.isInstanceOf[AtomicType]
logger.debug(s"Translating join ($jt, complex: $lComplex, $rComplex):", bind)

// Find an existing column that can serve as a discriminator
Expand All @@ -110,7 +110,7 @@ class ExpandSums extends Phase {
}
def find(t: Type, path: List[TermSymbol]): Vector[List[TermSymbol]] = t.structural match {
case StructType(defs) => defs.toSeq.flatMap { case (s, t) => find(t, s :: path) }(collection.breakOut)
case p: ProductType => p.numberedElements.flatMap { case (s, t) => find(t, s :: path) }.toVector
case p: ProductType => p.elements.iterator.zipWithIndex.flatMap { case (t, i) => find(t, ElementSymbol(i+1) :: path) }.toVector
case _: AtomicType => Vector(path)
case _ => Vector.empty
}
Expand Down Expand Up @@ -216,8 +216,8 @@ class ExpandSums extends Phase {
/** Strip nominal types and convert all atomic types to OptionTypes */
def toOptionColumns(tpe: Type): Type = tpe match {
case NominalType(_, str) => toOptionColumns(str)
case o @ OptionType(ch) if ch.structural.children.isEmpty => o
case t if t.children.isEmpty => OptionType(t)
case o @ OptionType(ch) if ch.structural.isInstanceOf[AtomicType] => o
case t: AtomicType => OptionType(t)
case t => t.mapChildren(toOptionColumns)
}

Expand Down

0 comments on commit c9bdcb1

Please sign in to comment.