/
ReorderOperations.scala
103 lines (86 loc) · 4.36 KB
/
ReorderOperations.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
package slick.compiler
import slick.ast._
import slick.ast.Util._
import slick.ast.TypeUtil._
import slick.util.{ConstArray, Ellipsis, ??}
/** Reorder certain stream operations for more efficient merging in `mergeToComprehensions`.
  *
  * Every node is visited bottom-up (existing node types are kept) and at most one of the
  * following rewrites is applied to it:
  *  - push a `Bind` into both sides of a `Union` (only when this cannot change the result)
  *  - push a `Filter` or a `CollectionCast` into both sides of a `Union`
  *  - drop a `Subquery` boundary directly on top of a `TableNode` or `Join`
  *  - push aliasing / literal projections into a `Subquery` boundary
  *  - push an upper-bound `Filter` on a ROWNUM calculation below an `AboveRownum` boundary
  *  - push a `BelowRowNumber` boundary below a `SortBy` or `Filter`
  */
class ReorderOperations extends Phase {
  val name: String = "reorderOperations"

  def apply(state: CompilerState): CompilerState = state.map(convert)

  /** Apply `convert1` to every node of `tree`, bottom-up, keeping the existing node types. */
  def convert(tree: Node): Node = tree.replace({ case n => convert1(n) }, keepType = true, bottomUp = true)

  /** Perform a single reordering step on `tree`, or return it unchanged if no rule applies. */
  def convert1(tree: Node): Node = tree match {
    // Push Bind into Union. For a distinct (non-`all`) Union this is only sound when the selection
    // is injective: deduplicating *after* a lossy projection (e.g. one that drops a column) can merge
    // rows which the original query (dedupe full rows, then project) returns separately.
    case n @ Bind(s1, u @ Union(l1, r1, all), sel) if all || isInjectiveProjection(s1, sel, u) =>
      logger.debug("Pushing Bind into both sides of a Union", Ellipsis(n, List(0, 0), List(0, 1)))
      val s1l, s1r = new AnonSymbol
      val n2 = Union(
        Bind(s1l, l1, sel.replace { case Ref(s) if s == s1 => Ref(s1l) }),
        Bind(s1r, r1, sel.replace { case Ref(s) if s == s1 => Ref(s1r) }),
        all).infer()
      logger.debug("Pushed Bind into both sides of a Union", Ellipsis(n2, List(0, 0), List(1, 0)))
      n2
    // Push Filter into Union (filtering commutes with duplicate elimination)
    case n @ Filter(s1, Union(l1, r1, all), pred) =>
      logger.debug("Pushing Filter into both sides of a Union", Ellipsis(n, List(0, 0), List(0, 1)))
      val s1l, s1r = new AnonSymbol
      val n2 = Union(
        Filter(s1l, l1, pred.replace { case Ref(s) if s == s1 => Ref(s1l) }),
        Filter(s1r, r1, pred.replace { case Ref(s) if s == s1 => Ref(s1r) }),
        all).infer()
      logger.debug("Pushed Filter into both sides of a Union", Ellipsis(n2, List(0, 0), List(1, 0)))
      n2
    // Push CollectionCast into Union
    case n @ CollectionCast(Union(l1, r1, all), cons) =>
      logger.debug("Pushing CollectionCast into both sides of a Union", Ellipsis(n, List(0, 0), List(0, 1)))
      val n2 = Union(CollectionCast(l1, cons), CollectionCast(r1, cons), all).infer()
      logger.debug("Pushed CollectionCast into both sides of a Union", Ellipsis(n2, List(0, 0), List(1, 0)))
      n2
    // Remove Subquery boundary on top of TableNode and Join
    case Subquery(n @ (_: TableNode | _: Join), _) => n
    // Push distinctness-preserving aliasing / literal projection into Subquery.AboveDistinct
    case n @ Bind(s, Subquery(from :@ CollectionType(_, tpe), Subquery.AboveDistinct), Pure(StructNode(defs), _))
        if isAliasingOrLiteral(s, defs) && isDistinctnessPreserving(s, defs, tpe) =>
      Subquery(n.copy(from = from), Subquery.AboveDistinct).infer()
    // Push any aliasing / literal projection into other Subquery
    case n @ Bind(s, Subquery(from, cond), Pure(StructNode(_defs), _))
        if cond != Subquery.AboveDistinct && isAliasingOrLiteral(s, _defs) =>
      Subquery(n.copy(from = from), cond).infer()
    // If a Filter checks an upper bound of a ROWNUM, push it into the AboveRownum boundary
    case filter @ Filter(s1,
        sq @ Subquery(bind @ Bind(_, _, Pure(StructNode(defs1), _)), Subquery.AboveRownum),
        Apply(Library.<= | Library.<, ConstArray(Select(Ref(rs), f1), _)))
        if rs == s1 && defs1.exists { case (f, n) => f == f1 && isRownumCalculation(n) } =>
      sq.copy(child = filter.copy(from = bind)).infer()
    // Push a BelowRowNumber boundary into SortBy
    case sq @ Subquery(n: SortBy, Subquery.BelowRowNumber) =>
      n.copy(from = convert1(sq.copy(child = n.from))).infer()
    // Push a BelowRowNumber boundary into Filter
    case sq @ Subquery(n: Filter, Subquery.BelowRowNumber) =>
      n.copy(from = convert1(sq.copy(child = n.from))).infer()
    case n => n
  }

  /** Check whether every projected value in `defs` is a path starting at `base`, a literal, or a
    * query parameter (i.e. the projection only renames / duplicates fields or adds constants). */
  def isAliasingOrLiteral(base: TermSymbol, defs: ConstArray[(TermSymbol, Node)]): Boolean = {
    val r = defs.iterator.map(_._2).forall {
      case FwdPath(s :: _) if s == base => true
      case _: LiteralNode => true
      case _: QueryParameter => true
      case _ => false
    }
    logger.debug(s"Bind from $base is aliasing / literal: $r")
    r
  }

  /** Check whether the projection `defs` over `base` references every field of the element type
    * `tpe`, so that (together with `isAliasingOrLiteral`) distinct inputs stay distinct.
    * Returns `false` (i.e. "don't rewrite") if `tpe` is not structurally a `StructType`, instead of
    * failing with a `MatchError`. */
  def isDistinctnessPreserving(base: TermSymbol, defs: ConstArray[(TermSymbol, Node)], tpe: Type): Boolean =
    tpe.structural match {
      case StructType(tDefs) =>
        val usedFields = defs.flatMap(_._2.collect[TermSymbol] {
          case Select(Ref(s), f) if s == base => f
        })
        (tDefs.map(_._1).toSet -- usedFields.toSeq).isEmpty
      case _ => false
    }

  /** Check whether `sel`, evaluated for generator `s` over `coll`, maps distinct elements of `coll`
    * to distinct results. Uses the same criterion as the `AboveDistinct` rule; anything that is not a
    * typed `Pure(StructNode(...))` projection is conservatively treated as non-injective. */
  private def isInjectiveProjection(s: TermSymbol, sel: Node, coll: Node): Boolean = (sel, coll.nodeType) match {
    case (Pure(StructNode(defs), _), CollectionType(_, tpe)) =>
      isAliasingOrLiteral(s, defs) && isDistinctnessPreserving(s, defs, tpe)
    case _ => false
  }

  /** Check whether `n` is a `RowNumber`, possibly wrapped in additions / subtractions. */
  def isRownumCalculation(n: Node): Boolean = n match {
    case Apply(Library.+ | Library.-, ch) => ch.exists(isRownumCalculation)
    case _: RowNumber => true
    case _ => false
  }
}