Permalink
Browse files

Integrate expandConditionals into expandSums

With the ill-typed outer join operators gone in 3.1, emulateOuterJoins
does not have to interact with Option expansion anymore, so there is no
need to have an `expandConditionals` phase separate from `expandSums`.
Integrating it avoids the use of phase state in `expandSums`.
  • Loading branch information...
szeiger committed Sep 7, 2015
1 parent d286c17 commit 6cb62b8778fea22341c32751a752c3bc0bc81537

This file was deleted.

Oops, something went wrong.
@@ -1,31 +1,29 @@
package slick.compiler
import slick.util.ConstArray
import slick.util.{ConstArrayOp, ConstArray}
import slick.{SlickTreeException, SlickException}
import slick.ast._
import Util._
import TypeUtil._
import scala.collection.mutable
/** Expand sum types and their catamorphisms to equivalent product type operations.
* The phase state is a flag indicating whether there is anything left to clean up in
* `expandConditionals`. */
/** Expand sum types and their catamorphisms to equivalent product type operations. */
class ExpandSums extends Phase {
val name = "expandSums"
type State = Boolean
def apply(state: CompilerState) = {
if(state.get(Phase.assignUniqueSymbols).map(_.nonPrimitiveOption).getOrElse(true)) {
val (n, multi) = tr(state.tree, Set.empty)
state.withNode(n) + (this -> multi)
} else state + (this -> false)
}
def apply(state: CompilerState) =
if(state.get(Phase.assignUniqueSymbols).map(_.nonPrimitiveOption).getOrElse(true)) state.map(expand)
else state
val Disc1 = LiteralNode(ScalaBaseType.optionDiscType.optionType, Option(1))
val DiscNone = LiteralNode(ScalaBaseType.optionDiscType.optionType, None)
def expand(n: Node): Node = {
val (n2, multi) = tr(n, Set.empty)
if(multi) expandConditionals(n2) else n2
}
/** Perform the sum expansion on a Node */
def tr(tree: Node, oldDiscCandidates: Set[(TypeSymbol, List[TermSymbol])]): (Node, Boolean) = {
val discCandidates = oldDiscCandidates ++ (tree match {
@@ -262,4 +260,66 @@ class ExpandSums extends Phase {
case _ => None
}
}
/** Expand multi-column conditional expressions and SilentCasts.
* Single-column conditionals involving NULL values are optimized away where possible. */
def expandConditionals(n: Node): Node = {
val invalid = mutable.HashSet.empty[TypeSymbol]
def invalidate(n: Node): Unit = invalid ++= n.nodeType.collect { case NominalType(ts, _) => ts }.toSeq
def tr(n: Node): Node = n.mapChildren(tr, keepType = true) match {
// Expand multi-column SilentCasts
case cast @ Library.SilentCast(ch) :@ Type.Structural(ProductType(typeCh)) =>
invalidate(ch)
val elems = typeCh.zipWithIndex.map { case (t, idx) => tr(Library.SilentCast.typed(t, ch.select(ElementSymbol(idx+1))).infer()) }
ProductNode(elems).infer()
case Library.SilentCast(ch) :@ Type.Structural(StructType(typeCh)) =>
invalidate(ch)
val elems = typeCh.map { case (sym, t) => (sym, tr(Library.SilentCast.typed(t, ch.select(sym)).infer())) }
StructNode(elems).infer()
// Optimize trivial SilentCasts
case Library.SilentCast(v :@ tpe) :@ tpe2 if tpe.structural == tpe2.structural =>
invalidate(v)
v
case Library.SilentCast(Library.SilentCast(ch)) :@ tpe => tr(Library.SilentCast.typed(tpe, ch).infer())
case Library.SilentCast(LiteralNode(None)) :@ (tpe @ OptionType.Primitive(_)) => LiteralNode(tpe, None).infer()
// Expand multi-column IfThenElse
case (cond @ IfThenElse(_)) :@ Type.Structural(ProductType(chTypes)) =>
val ch = ConstArrayOp.from(1 to chTypes.length).map { idx =>
val sym = ElementSymbol(idx)
tr(cond.mapResultClauses(n => n.select(sym)).infer())
}
ProductNode(ch).infer()
case (cond @ IfThenElse(_)) :@ Type.Structural(StructType(chTypes)) =>
val ch = chTypes.map { case (sym, _) =>
(sym, tr(cond.mapResultClauses(n => n.select(sym)).infer()))
}
StructNode(ch).infer()
// Optimize null-propagating single-column IfThenElse
case IfThenElse(ConstArray(Library.==(r, LiteralNode(null)), Library.SilentCast(LiteralNode(None)), c @ Library.SilentCast(r2))) if r == r2 => c
// Fix Untyped nulls in else clauses
case cond @ IfThenElse(clauses) if (clauses.last match { case LiteralNode(None) :@ OptionType(ScalaBaseType.nullType) => true; case _ => false }) =>
cond.copy(clauses.init :+ LiteralNode(cond.nodeType, None))
// Resolve Selects into ProductNodes and StructNodes
case Select(ProductNode(ch), ElementSymbol(idx)) => ch(idx-1)
case Select(StructNode(ch), sym) => ch.find(_._1 == sym).get._2
case n2 @ Pure(_, ts) if n2 ne n =>
invalid += ts
n2
case n => n
}
val n2 = tr(n)
logger.debug("Invalidated TypeSymbols: "+invalid.mkString(", "))
n2.replace({
case n: PathElement if n.nodeType.containsSymbol(invalid) => n.untyped
}, bottomUp = true).infer()
}
}
@@ -113,7 +113,6 @@ object QueryCompiler {
Phase.expandSums,
// optional removeTakeDrop goes here
// optional emulateOuterJoins goes here
Phase.expandConditionals,
Phase.expandRecords,
Phase.flattenProjections,
/* Optimize for SQL */
@@ -176,7 +175,6 @@ object Phase {
val forceOuterBinds = new ForceOuterBinds
val removeMappedTypes = new RemoveMappedTypes
val expandSums = new ExpandSums
val expandConditionals = new ExpandConditionals
val expandRecords = new ExpandRecords
val flattenProjections = new FlattenProjections
val createAggregates = new CreateAggregates
@@ -54,7 +54,7 @@ trait RelationalProfile extends BasicProfile with RelationalTableComponent
val canJoinRight = capabilities contains RelationalProfile.capabilities.joinRight
val canJoinFull = capabilities contains RelationalProfile.capabilities.joinFull
if(canJoinLeft && canJoinRight && canJoinFull) base
else base.addBefore(new EmulateOuterJoins(canJoinLeft, canJoinRight), Phase.expandConditionals)
else base.addBefore(new EmulateOuterJoins(canJoinLeft, canJoinRight), Phase.expandRecords)
}
class TableQueryExtensionMethods[T <: Table[_], U](val q: Query[T, U, Seq] with TableQuery[T]) {

0 comments on commit 6cb62b8

Please sign in to comment.