Permalink
Browse files

Better integration of expandConditionals into expandSums

  • Loading branch information...
szeiger committed Sep 9, 2015
1 parent 07cd37a commit f82a8fea0f2ec6f4e3894b2dcb51d84db2306d84
Showing with 80 additions and 83 deletions.
  1. +80 −83 slick/src/main/scala/slick/compiler/ExpandSums.scala
@@ -13,97 +13,94 @@ class ExpandSums extends Phase {
val name = "expandSums"
def apply(state: CompilerState) =
if(state.get(Phase.assignUniqueSymbols).map(_.nonPrimitiveOption).getOrElse(true)) state.map(expand)
if(state.get(Phase.assignUniqueSymbols).map(_.nonPrimitiveOption).getOrElse(true)) state.map(expandSums)
else state
val Disc1 = LiteralNode(ScalaBaseType.optionDiscType.optionType, Option(1))
val DiscNone = LiteralNode(ScalaBaseType.optionDiscType.optionType, None)
def expand(n: Node): Node = {
val (n2, multi) = tr(n, Set.empty)
if(multi) expandConditionals(n2) else n2
}
/** Perform the sum expansion on a Node */
def tr(tree: Node, oldDiscCandidates: Set[(TypeSymbol, List[TermSymbol])]): (Node, Boolean) = {
val discCandidates = oldDiscCandidates ++ (tree match {
case Filter(_, _, p) => collectDiscriminatorCandidates(p)
case Bind(_, j: Join, _) => collectDiscriminatorCandidates(j.on)
case _ => Set.empty
})
def expandSums(n: Node): Node = {
var multi = false
val tree2 = tree.mapChildren({ n =>
val (n2, flag) = tr(n, discCandidates)
multi |= flag
n2
}, keepType = true)
val tree3 = tree2 match {
// Expand multi-column null values in ELSE branches (used by Rep[Option].filter) with correct type
case IfThenElse(ConstArray(pred, then1 :@ tpe, LiteralNode(None) :@ OptionType(ScalaBaseType.nullType))) =>
multi = true
IfThenElse(ConstArray(pred, then1, buildMultiColumnNone(tpe))) :@ tpe
// Primitive OptionFold representing GetOrElse -> translate to GetOrElse
case OptionFold(from :@ OptionType.Primitive(_), LiteralNode(v), Ref(s), gen) if s == gen =>
GetOrElse(from, () => v).infer()
// Primitive OptionFold -> translate to null check
case OptionFold(from :@ OptionType.Primitive(_), ifEmpty, map, gen) =>
val pred = Library.==.typed[Boolean](from, LiteralNode(null))
val n2 = (ifEmpty, map) match {
case (LiteralNode(true), LiteralNode(false)) => pred
case (LiteralNode(false), LiteralNode(true)) => Library.Not.typed[Boolean](pred)
case _ =>
val ifDefined = map.replace({
case r @ Ref(s) if s == gen => silentCast(r.nodeType, from)
}, keepType = true)
val ifEmpty2 = silentCast(ifDefined.nodeType.structural, ifEmpty)
IfThenElse(ConstArray(pred, ifEmpty2, ifDefined))
}
n2.infer()
// Other OptionFold -> translate to discriminator check
case OptionFold(from, ifEmpty, map, gen) =>
multi = true
val left = from.select(ElementSymbol(1)).infer()
val pred = Library.==.typed[Boolean](left, LiteralNode(null))
val n2 = (ifEmpty, map) match {
case (LiteralNode(true), LiteralNode(false)) => pred
case (LiteralNode(false), LiteralNode(true)) => Library.Not.typed[Boolean](pred)
case _ =>
val ifDefined = map.replace({
case r @ Ref(s) if s == gen => silentCast(r.nodeType, from.select(ElementSymbol(2)).infer())
}, keepType = true)
val ifEmpty2 = silentCast(ifDefined.nodeType.structural, ifEmpty)
if(left == Disc1) ifDefined else IfThenElse(ConstArray(Library.Not.typed[Boolean](pred), ifDefined, ifEmpty2))
}
n2.infer()
// Primitive OptionApply -> leave unchanged
case n @ OptionApply(_) :@ OptionType.Primitive(_) => n
// Other OptionApply -> translate to product form
case n @ OptionApply(ch) =>
multi = true
ProductNode(ConstArray(Disc1, silentCast(toOptionColumns(ch.nodeType), ch))).infer()
// Non-primitive GetOrElse
// (.get is only defined on primitive Options, but this can occur inside of HOFs like .map)
case g @ GetOrElse(ch :@ tpe, _) =>
tpe match {
case OptionType.Primitive(_) => g
case _ => throw new SlickException(".get may only be called on Options of top-level primitive types")
}
// Option-extended left outer, right outer or full outer join
case bind @ Bind(bsym, Join(_, _, _, _, jt, _), _) if jt == JoinType.LeftOption || jt == JoinType.RightOption || jt == JoinType.OuterOption =>
multi = true
translateJoin(bind, discCandidates)
case n => n
/** Perform the sum expansion on a Node */
def tr(tree: Node, oldDiscCandidates: Set[(TypeSymbol, List[TermSymbol])]): Node = {
val discCandidates = oldDiscCandidates ++ (tree match {
case Filter(_, _, p) => collectDiscriminatorCandidates(p)
case Bind(_, j: Join, _) => collectDiscriminatorCandidates(j.on)
case _ => Set.empty
})
val tree2 = tree.mapChildren(tr(_, discCandidates), keepType = true)
val tree3 = tree2 match {
// Expand multi-column null values in ELSE branches (used by Rep[Option].filter) with correct type
case IfThenElse(ConstArray(pred, then1 :@ tpe, LiteralNode(None) :@ OptionType(ScalaBaseType.nullType))) =>
multi = true
IfThenElse(ConstArray(pred, then1, buildMultiColumnNone(tpe))) :@ tpe
// Primitive OptionFold representing GetOrElse -> translate to GetOrElse
case OptionFold(from :@ OptionType.Primitive(_), LiteralNode(v), Ref(s), gen) if s == gen =>
GetOrElse(from, () => v).infer()
// Primitive OptionFold -> translate to null check
case OptionFold(from :@ OptionType.Primitive(_), ifEmpty, map, gen) =>
val pred = Library.==.typed[Boolean](from, LiteralNode(null))
val n2 = (ifEmpty, map) match {
case (LiteralNode(true), LiteralNode(false)) => pred
case (LiteralNode(false), LiteralNode(true)) => Library.Not.typed[Boolean](pred)
case _ =>
val ifDefined = map.replace({
case r @ Ref(s) if s == gen => silentCast(r.nodeType, from)
}, keepType = true)
val ifEmpty2 = silentCast(ifDefined.nodeType.structural, ifEmpty)
IfThenElse(ConstArray(pred, ifEmpty2, ifDefined))
}
n2.infer()
// Other OptionFold -> translate to discriminator check
case OptionFold(from, ifEmpty, map, gen) =>
multi = true
val left = from.select(ElementSymbol(1)).infer()
val pred = Library.==.typed[Boolean](left, LiteralNode(null))
val n2 = (ifEmpty, map) match {
case (LiteralNode(true), LiteralNode(false)) => pred
case (LiteralNode(false), LiteralNode(true)) => Library.Not.typed[Boolean](pred)
case _ =>
val ifDefined = map.replace({
case r @ Ref(s) if s == gen => silentCast(r.nodeType, from.select(ElementSymbol(2)).infer())
}, keepType = true)
val ifEmpty2 = silentCast(ifDefined.nodeType.structural, ifEmpty)
if(left == Disc1) ifDefined else IfThenElse(ConstArray(Library.Not.typed[Boolean](pred), ifDefined, ifEmpty2))
}
n2.infer()
// Primitive OptionApply -> leave unchanged
case n @ OptionApply(_) :@ OptionType.Primitive(_) => n
// Other OptionApply -> translate to product form
case n @ OptionApply(ch) =>
multi = true
ProductNode(ConstArray(Disc1, silentCast(toOptionColumns(ch.nodeType), ch))).infer()
// Non-primitive GetOrElse
// (.get is only defined on primitive Options, but this can occur inside of HOFs like .map)
case g @ GetOrElse(ch :@ tpe, _) =>
tpe match {
case OptionType.Primitive(_) => g
case _ => throw new SlickException(".get may only be called on Options of top-level primitive types")
}
// Option-extended left outer, right outer or full outer join
case bind @ Bind(bsym, Join(_, _, _, _, jt, _), _) if jt == JoinType.LeftOption || jt == JoinType.RightOption || jt == JoinType.OuterOption =>
multi = true
translateJoin(bind, discCandidates)
case n => n
}
val tree4 = fuse(tree3)
tree4 :@ trType(tree4.nodeType)
}
val tree4 = fuse(tree3)
(tree4 :@ trType(tree4.nodeType), multi)
val n2 = tr(n, Set.empty)
if(multi) expandConditionals(n2) else n2
}
/** Translate an Option-extended left outer, right outer or full outer join */

0 comments on commit f82a8fe

Please sign in to comment.