Skip to content

Commit

Permalink
Compiler improvements:
Browse files Browse the repository at this point in the history
- Move structural type reconstruction down from inferTypes into
  expandTables where a full tree transformation has to be done anyway.

- Remove the assignTypes phase. Types are now preserved everywhere.

- Some simplifications.
  • Loading branch information
szeiger committed Jun 22, 2015
1 parent 8baa079 commit 3470fa0
Show file tree
Hide file tree
Showing 16 changed files with 134 additions and 156 deletions.
Expand Up @@ -19,18 +19,18 @@ class PagingTest extends AsyncTest[RelationalTestDB] {
def q5 = q1 take 5 drop 3
val q6 = q1 take 0

seq(
ids.schema.create,
ids ++= (1 to 10),
q1.result.map(_ shouldBe (1 to 10).toList),
q2.result.map(_ shouldBe (1 to 5).toList),
ifCap(rcap.pagingDrop)(seq(
for {
_ <- ids.schema.create
_ <- ids ++= (1 to 10)
_ <- mark("q1", q1.result).map(_ shouldBe (1 to 10).toList)
_ <- mark("q2", q2.result).map(_ shouldBe (1 to 5).toList)
_ <- ifCap(rcap.pagingDrop)(seq(
q3.result.map(_ shouldBe (6 to 10).toList),
q4.result.map(_ shouldBe (6 to 8).toList),
q5.result.map(_ shouldBe (4 to 5).toList)
)),
q6.result.map(_ shouldBe Nil)
)
))
_ <- mark("q6", q6.result).map(_ shouldBe Nil)
} yield ()
}

def testCompiledPagination = {
Expand Down
2 changes: 1 addition & 1 deletion slick/src/main/scala/slick/ast/Node.scala
Expand Up @@ -567,7 +567,7 @@ object FwdPath {
/** A Node representing a database table. */
final case class TableNode(schemaName: Option[String], tableName: String, identity: TableIdentitySymbol, driverTable: Any, baseIdentity: TableIdentitySymbol) extends NullaryNode with TypedNode {
type Self = TableNode
def tpe = CollectionType(TypedCollectionTypeConstructor.seq, NominalType(identity, UnassignedStructuralType(identity)))
def tpe = CollectionType(TypedCollectionTypeConstructor.seq, NominalType(identity, UnassignedType))
def nodeRebuild = copy()
override def getDumpInfo = super.getDumpInfo.copy(name = "Table", mainInfo = schemaName.map(_ + ".").getOrElse("") + tableName)
}
Expand Down
6 changes: 0 additions & 6 deletions slick/src/main/scala/slick/ast/Type.scala
Expand Up @@ -205,12 +205,6 @@ case object UnassignedType extends AtomicType {
def classTag = throw new SlickException("UnassignedType does not have a ClassTag")
}

/** The type of a structural view of a NominalType before computing the
* proper type in the `inferTypes` phase. */
final case class UnassignedStructuralType(sym: TypeSymbol) extends AtomicType {
def classTag = throw new SlickException("UnassignedStructuralType does not have a ClassTag")
}

/** A type with a name, as used by tables.
*
* Compiler phases which change types may keep their own representation
Expand Down
13 changes: 0 additions & 13 deletions slick/src/main/scala/slick/compiler/AssignTypes.scala

This file was deleted.

25 changes: 25 additions & 0 deletions slick/src/main/scala/slick/compiler/ExpandRecords.scala
@@ -0,0 +1,25 @@
package slick.compiler

import slick.ast._
import Util._

/** Expand paths of record types to reference all fields individually and
* recreate the record structure at the call site. */
class ExpandRecords extends Phase {
val name = "expandRecords"

def apply(state: CompilerState) =
state.map(_.replace({ case n @ Path(_) => expandPath(n) }, keepType = true))

def expandPath(n: Node): Node = n.nodeType.structural match {
case StructType(ch) =>
StructNode(ch.map { case (s, t) =>
(s, expandPath(n.select(s).nodeTypedOrCopy(t)))
}(collection.breakOut)).nodeTyped(n.nodeType)
case p: ProductType =>
ProductNode(p.numberedElements.map { case (s, t) =>
expandPath(n.select(s).nodeTypedOrCopy(t))
}.toVector).nodeTyped(n.nodeType)
case t => n
}
}
58 changes: 58 additions & 0 deletions slick/src/main/scala/slick/compiler/ExpandTables.scala
@@ -0,0 +1,58 @@
package slick.compiler

import slick.ast._
import Util._
import TypeUtil._

/** Expand table-valued expressions in the result type to their star projection and compute the
* missing structural expansions of table types. */
class ExpandTables extends Phase {
val name = "expandTables"

def apply(state: CompilerState) = state.map { n => ClientSideOp.mapServerSide(n) { tree =>
// Find table fields
val structs = tree.collect[(TypeSymbol, (Symbol, Type))] {
case s @ Select(_ :@ (n: NominalType), sym) => n.sourceNominalType.sym -> (sym -> s.nodeType)
}.groupBy(_._1).mapValues(v => StructType(v.map(_._2).toMap.toIndexedSeq))
logger.debug("Found Selects for NominalTypes: "+structs.keySet.mkString(", "))

// Check for table types
val tsyms: Set[TableIdentitySymbol] =
tree.nodeType.collect { case NominalType(sym: TableIdentitySymbol, _) => sym }.toSet
logger.debug("Tables for expansion in result type: " + tsyms.mkString(", "))

val tree2 = tree.replace({
case TableExpansion(_, t, _) => t
case n => n :@ n.nodeType.replace { case NominalType(tsym, UnassignedType) => NominalType(tsym, structs(tsym)) }
}, keepType = true, bottomUp = true)

if(tsyms.isEmpty) tree2 else {
// Find the corresponding TableExpansions
val tables: Map[TableIdentitySymbol, (Symbol, Node)] = tree.collect {
case TableExpansion(s, TableNode(_, _, ts, _, _), ex) if tsyms contains ts => ts -> (s, ex)
}.toMap
logger.debug("Table expansions: " + tables.mkString(", "))
// Create a mapping that expands the tables
val sym = new AnonSymbol
val mapping = createResult(tables, Ref(sym), tree2.nodeType.asCollectionType.elementType)
.nodeWithComputedType(SymbolScope.empty + (sym -> tree2.nodeType.asCollectionType.elementType), typeChildren = true)
Bind(sym, tree2, Pure(mapping)).nodeWithComputedType()
}
}}

/** Create an expression that copies a structured value, expanding tables in it. */
def createResult(expansions: Map[TableIdentitySymbol, (Symbol, Node)], path: Node, tpe: Type): Node = tpe match {
case p: ProductType =>
ProductNode(p.numberedElements.map { case (s, t) => createResult(expansions, Select(path, s), t) }.toVector)
case NominalType(tsym: TableIdentitySymbol, _) if expansions contains tsym =>
val (sym, exp) = expansions(tsym)
exp.replace { case Ref(s) if s == sym => path }
case tpe: NominalType => createResult(expansions, path, tpe.structuralView)
case m: MappedScalaType =>
TypeMapping(createResult(expansions, path, m.baseType), m.mapper, m.classTag)
case OptionType(el) =>
val gen = new AnonSymbol
OptionFold(path, LiteralNode.nullOption, OptionApply(createResult(expansions, Ref(gen), el)), gen)
case _ => path
}
}
24 changes: 24 additions & 0 deletions slick/src/main/scala/slick/compiler/FixRowNumberOrdering.scala
@@ -0,0 +1,24 @@
package slick.compiler

import slick.ast._
import Util._

/** Inject the proper orderings into the RowNumber nodes produced earlier by
* the resolveFixJoins phase. */
class FixRowNumberOrdering extends Phase {
val name = "fixRowNumberOrdering"

def apply(state: CompilerState) =
if(state.get(Phase.resolveZipJoins).get) state.map(n => fix(n)) else state

/** Push ORDER BY into RowNumbers in ordered Comprehensions. */
def fix(n: Node, parent: Option[Comprehension] = None): Node = (n, parent) match {
case (r @ RowNumber(_), Some(c)) if !c.orderBy.isEmpty =>
RowNumber(c.orderBy).nodeTyped(r.nodeType)
case (c: Comprehension, _) => c.nodeMapScopedChildren {
case (Some(gen), ch) => fix(ch, None)
case (None, ch) => fix(ch, Some(c))
}.nodeWithComputedType()
case (n, _) => n.nodeMapChildren(ch => fix(ch, parent), keepType = true)
}
}
Expand Up @@ -5,69 +5,6 @@ import slick.ast._
import Util._
import TypeUtil._

/** Expand table-valued expressions in the result type to their star projection. */
class ExpandTables extends Phase {
val name = "expandTables"

def apply(state: CompilerState) = state.map { n => ClientSideOp.mapServerSide(n) { tree =>
// Check for table types
val tsyms: Set[TableIdentitySymbol] =
tree.nodeType.collect { case NominalType(sym: TableIdentitySymbol, _) => sym }.toSet
logger.debug("Tables for expansion in result type: " + tsyms.mkString(", "))
val tree2 = tree.replace({ case TableExpansion(_, t, _) => t }, keepType = true)
if(tsyms.isEmpty) tree2 else {
// Find the corresponding TableExpansions
val tables: Map[TableIdentitySymbol, (Symbol, Node)] = tree.collect {
case TableExpansion(s, TableNode(_, _, ts, _, _), ex) if tsyms contains ts => ts -> (s, ex)
}.toMap
logger.debug("Table expansions: " + tables.mkString(", "))
// Create a mapping that expands the tables
val sym = new AnonSymbol
val mapping = createResult(tables, Ref(sym), tree.nodeType.asCollectionType.elementType)
.nodeWithComputedType(SymbolScope.empty + (sym -> tree.nodeType.asCollectionType.elementType), typeChildren = true)
Bind(sym, tree2, Pure(mapping)).nodeWithComputedType()
}
}}

/** Create an expression that copies a structured value, expanding tables in it. */
def createResult(expansions: Map[TableIdentitySymbol, (Symbol, Node)], path: Node, tpe: Type): Node = tpe match {
case p: ProductType =>
ProductNode(p.numberedElements.map { case (s, t) => createResult(expansions, Select(path, s), t) }.toVector)
case NominalType(tsym: TableIdentitySymbol, _) if expansions contains tsym =>
val (sym, exp) = expansions(tsym)
exp.replace { case Ref(s) if s == sym => path }
case tpe: NominalType => createResult(expansions, path, tpe.structuralView)
case m: MappedScalaType =>
TypeMapping(createResult(expansions, path, m.baseType), m.mapper, m.classTag)
case OptionType(el) =>
val gen = new AnonSymbol
OptionFold(path, LiteralNode.nullOption, OptionApply(createResult(expansions, Ref(gen), el)), gen)
case _ => path
}
}

/** Expand paths of record types to reference all fields individually and
* recreate the record structure at the call site. */
class ExpandRecords extends Phase {
val name = "expandRecords"

def apply(state: CompilerState) = state.map { tree =>
tree.replace({ case n @ Path(_) => expandPath(n) }, keepType = true)
}

def expandPath(n: Node): Node = n.nodeType.structural match {
case StructType(ch) =>
StructNode(ch.map { case (s, t) =>
(s, expandPath(n.select(s).nodeTypedOrCopy(t)))
}(collection.breakOut)).nodeTyped(n.nodeType)
case p: ProductType =>
ProductNode(p.numberedElements.map { case (s, t) =>
expandPath(n.select(s).nodeTypedOrCopy(t))
}.toVector).nodeTyped(n.nodeType)
case t => n
}
}

/** Flatten all Pure node contents into a single StructNode. */
class FlattenProjections extends Phase {
val name = "flattenProjections"
Expand Down
15 changes: 3 additions & 12 deletions slick/src/main/scala/slick/compiler/InferTypes.scala
Expand Up @@ -4,19 +4,10 @@ import slick.ast._
import Util._
import TypeUtil._

/** Infer types and compute missing structural views for all nominal table types. */
/** Infer all missing types. */
class InferTypes extends Phase {
val name = "inferTypes"

def apply(state: CompilerState) = state.map { tree =>
val tree2 = tree.nodeWithComputedType(new DefaultSymbolScope(Map.empty), true, false)
val structs = tree2.collect[(TypeSymbol, (Symbol, Type))] {
case s @ Select(_ :@ (n: NominalType), sym) => n.sourceNominalType.sym -> (sym -> s.nodeType)
}.groupBy(_._1).mapValues(v => StructType(v.map(_._2).toMap.toIndexedSeq))
logger.debug("Found Selects for NominalTypes: "+structs.keySet.mkString(", "))
def tr(n: Node): Node = n.nodeMapChildren(tr, keepType = true).nodeTypedOrCopy(n.nodeType.replace {
case UnassignedStructuralType(tsym) if structs.contains(tsym) => structs(tsym)
})
tr(tree2)
}
def apply(state: CompilerState) =
state.map(_.nodeWithComputedType(new DefaultSymbolScope(Map.empty), typeChildren = true))
}
5 changes: 1 addition & 4 deletions slick/src/main/scala/slick/compiler/QueryCompiler.scala
Expand Up @@ -100,16 +100,15 @@ object QueryCompiler {
Phase.expandRecords,
Phase.flattenProjections,
/* Optimize for SQL */
Phase.createAggregates,
Phase.rewriteJoins,
Phase.verifySymbols,
Phase.assignTypes,
Phase.relabelUnions
)

/** Extra phases for translation to SQL comprehensions */
val sqlPhases = Vector(
// optional access:existsToCount goes here
Phase.createAggregates,
Phase.resolveZipJoins,
Phase.pruneProjections,
Phase.mergeToComprehensions,
Expand All @@ -123,7 +122,6 @@ object QueryCompiler {

/** Extra phases needed for the QueryInterpreter */
val interpreterPhases = Vector(
// remove createAggregates from standard phases
Phase.pruneProjections,
Phase.createResultSetMapping,
Phase.removeFieldNames
Expand Down Expand Up @@ -166,7 +164,6 @@ object Phase {
val rewriteJoins = new RewriteJoins
val verifySymbols = new VerifySymbols
val resolveZipJoins = new ResolveZipJoins
val assignTypes = new AssignTypes
val relabelUnions = new RelabelUnions
val mergeToComprehensions = new MergeToComprehensions
val fixRowNumberOrdering = new FixRowNumberOrdering
Expand Down
9 changes: 3 additions & 6 deletions slick/src/main/scala/slick/compiler/RelabelUnions.scala
Expand Up @@ -9,12 +9,9 @@ import Util._
class RelabelUnions extends Phase {
val name = "relabelUnions"

def apply(state: CompilerState) = state.map(relabelUnions)

def relabelUnions(n: Node): Node = n.replace({
def apply(state: CompilerState) = state.map(_.replace({
case u @ Union(Bind(_, _, Pure(StructNode(ls), lts)), rb @ Bind(_, _, Pure(StructNode(rs), _)), _, _, _) =>
val rs2 = (ls, rs).zipped.map { case ((s, _), (_, n)) => (s, n) }
val u2 = u.copy(right = rb.copy(select = Pure(StructNode(rs2), lts)))
u2.nodeMapChildren(relabelUnions).nodeWithComputedType()
}, keepType = true)
u.copy(right = rb.copy(select = Pure(StructNode(rs2), lts))).nodeWithComputedType()
}, keepType = true, bottomUp = true))
}
35 changes: 2 additions & 33 deletions slick/src/main/scala/slick/compiler/ResolveZipJoins.scala
@@ -1,10 +1,7 @@
package slick.compiler

import scala.collection.mutable.{HashMap, ArrayBuffer}
import slick.{SlickTreeException, SlickException}
import slick.ast._
import Util._
import ExtraUtil._
import TypeUtil._

/** Rewrite zip joins into a form suitable for SQL (using inner joins and
Expand All @@ -14,12 +11,12 @@ import TypeUtil._
* Binds need to select Pure(StructNode(...)) which should be the outcome
* of Phase.flattenProjections. */
class ResolveZipJoins extends Phase {
type State = ResolveZipJoinsState
type State = Boolean
val name = "resolveZipJoins"

def apply(state: CompilerState) = {
val n2 = resolveZipJoins(state.tree)
state + (this -> new State(n2 ne state.tree)) withNode n2
state + (this -> (n2 ne state.tree)) withNode n2
}

def resolveZipJoins(n: Node): Node = (n match {
Expand Down Expand Up @@ -60,31 +57,3 @@ class ResolveZipJoins extends Phase {
case n => n
}).nodeMapChildren(resolveZipJoins, keepType = true)
}

class ResolveZipJoinsState(val hasRowNumber: Boolean)

/** Inject the proper orderings into the RowNumber nodes produced earlier by
* the resolveFixJoins phase. */
class FixRowNumberOrdering extends Phase {
val name = "fixRowNumberOrdering"

def apply(state: CompilerState) = state.map { n =>
if(state.get(Phase.resolveZipJoins).map(_.hasRowNumber).getOrElse(true))
fixRowNumberOrdering(n)
else {
logger.debug("No row numbers to fix")
n
}
}

/** Push ORDER BY into RowNumbers in ordered Comprehensions. */
def fixRowNumberOrdering(n: Node, parent: Option[Comprehension] = None): Node = (n, parent) match {
case (r @ RowNumber(_), Some(c)) if !c.orderBy.isEmpty =>
RowNumber(c.orderBy).nodeTyped(r.nodeType)
case (c: Comprehension, _) => c.nodeMapScopedChildren {
case (Some(gen), ch) => fixRowNumberOrdering(ch, None)
case (None, ch) => fixRowNumberOrdering(ch, Some(c))
}.nodeWithComputedType()
case (n, _) => n.nodeMapChildren(ch => fixRowNumberOrdering(ch, parent), keepType = true)
}
}
10 changes: 5 additions & 5 deletions slick/src/main/scala/slick/compiler/RewriteJoins.scala
Expand Up @@ -23,19 +23,19 @@ class RewriteJoins extends Phase {
case Bind(s1, f1, Bind(s2, Filter(s3, f2, pred), select)) =>
logger.debug("Hoisting flatMapped Filter from:", Ellipsis(n, List(0), List(1, 0, 0)))
val sn, sj1, sj2 = new AnonSymbol
val j = Join(sj1, sj2, f1, f2.replace {
val j = Join(sj1, sj2, f1, f2.replace({
case Ref(s) if s == s1 => Ref(sj1) :@ f1.nodeType.asCollectionType.elementType
}, JoinType.Inner, pred.replace {
}, retype = true, bottomUp = true), JoinType.Inner, pred.replace({
case Ref(s) if s == s1 => Ref(sj1) :@ f1.nodeType.asCollectionType.elementType
case Ref(s) if s == s3 => Ref(sj2) :@ f2.nodeType.asCollectionType.elementType
}).nodeWithComputedType()
}, retype = true, bottomUp = true)).nodeWithComputedType()
val refSn = Ref(sn) :@ j.nodeType.asCollectionType.elementType
val ref1 = Select(refSn, ElementSymbol(1))
val ref2 = Select(refSn, ElementSymbol(2))
val sel2 = select.replace {
val sel2 = select.replace({
case Ref(s) :@ tpe if s == s1 => ref1 :@ tpe
case Ref(s) :@ tpe if s == s2 => ref2 :@ tpe
}
}, retype = true, bottomUp = true)
val res = Bind(sn, hoistFilters(j), sel2).nodeWithComputedType()
logger.debug("Hoisted flatMapped Filter in:", Ellipsis(res, List(0, 0), List(0, 1)))
flattenAliasingMap(res)
Expand Down

0 comments on commit 3470fa0

Please sign in to comment.