/
UnrollTailBinds.scala
62 lines (51 loc) · 1.65 KB
/
UnrollTailBinds.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
package slick.compiler
import slick.ast.Library.AggregateFunctionSymbol
import scala.collection.mutable.{HashSet, HashMap}
import slick.SlickException
import slick.ast._
import TypeUtil._
import Util._
/** Compiler phase `unrollTailBinds`.
  *
  * Matches a `Bind(br, downstream, Bind(bi, bf, select))` whose generator `downstream` is itself
  * a `Bind`. Following the generator chain of `downstream` through nested `Bind` generators, if
  * the innermost bind has the shape `Bind(bo, Filter(fa, ff1, where1), Filter(fb, ff2, where2))`,
  * that bind is rebuilt as
  * {{{
  * Bind(bo, Filter(fa, ff1', where1'),
  *   Bind(bm, Filter(fb, ff2', where2'),
  *     Bind(bi, bf', select')))
  * }}}
  * where `bm` is a fresh symbol and each primed part has had its `br`-headed paths re-pointed
  * at `bm`. The enclosing binds of the chain keep their symbols and selects, and the rebuilt
  * chain takes the place of the outer bind. If the chain does not end in that shape, the outer
  * bind is kept unchanged. In both cases the phase then recurses into the children.
  */
class UnrollTailBinds extends Phase {
  val name: String = "unrollTailBinds"

  def apply(state: CompilerState): CompilerState = state.map(tr)

  /** Applies the rewrite described on the class to `n`, then recurses into its children.
    *
    * @param n the tree to transform
    * @return the transformed tree
    */
  def tr(n: Node): Node = n match {
    case bb @ Bind(br, downstream: Bind, Bind(bi, bf, select)) =>
      // Fresh symbol for the new Bind introduced over the second Filter.
      val bm = new AnonSymbol

      // Re-points every path whose head symbol is `br` at `bm`. One bottom-up `replace` is
      // sufficient because it visits every node of the subtree; re-running it on each child
      // would only re-walk nodes that have already been rewritten.
      // NOTE(review): confirm which end of the path `Path`'s extractor lists first in the Slick
      // version in use; the guard below tests the head of that list.
      def rep(node: Node): Node =
        node.replace({
          case Path(sym :: tail) if sym == br => Path(bm :: tail)
        }, bottomUp = true, keepType = true)

      // Descends through nested `Bind` generators to the innermost
      // `Bind(_, Filter(...), Filter(...))` and splices the tail bind in there.
      // Returns None when the chain ends in any other shape.
      def replaceDownstream(node: Bind): Option[Node] = node match {
        case Bind(bo, Filter(fa, ff1, where1), Filter(fb, ff2, where2)) =>
          Some(
            Bind(bo,
              Filter(fa, rep(ff1), rep(where1)),
              Bind(bm,
                Filter(fb, rep(ff2), rep(where2)),
                Bind(bi, rep(bf), rep(select)))))
        case Bind(bo1, inner: Bind, sel1) =>
          // Intermediate binds are rebuilt around the rewritten generator; their selects are kept.
          replaceDownstream(inner).map(replacement => Bind(bo1, replacement, sel1))
        case _ => None
      }

      replaceDownstream(downstream).getOrElse(bb).mapChildren(tr)

    case other => other.mapChildren(tr)
  }
}