-
-
Notifications
You must be signed in to change notification settings - Fork 609
/
SpecializeParameters.scala
39 lines (34 loc) · 1.58 KB
/
SpecializeParameters.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
package slick.compiler
import slick.ast._
import slick.ast.Util._
import slick.util.ConstArray
/** Specialize the AST for edge cases of query parameters. This is required for
 * compiling `take(0)` for some databases which do not allow `LIMIT 0`.
 *
 * For every [[Comprehension]] whose `fetch` clause is a [[QueryParameter]], the server-side
 * tree is wrapped in a [[ParameterSwitch]]. When the parameter's extractor yields 0, the
 * switch selects a copy of the tree in which that fetch is the literal `0L`. Otherwise it
 * selects a copy that keeps the parameter, re-declared with type `Long`. Each affected
 * comprehension produces two copies of the tree it is applied to, so the result grows with
 * the number of parameterized fetch clauses.
 */
class SpecializeParameters extends Phase {
  val name: String = "specializeParameters"

  def apply(state: CompilerState): CompilerState =
    state.map(ClientSideOp.mapServerSide(_)(transformServerSide))

  /** Specialize every parameterized fetch clause in a server-side tree.
   *
   * @param n the server-side query tree
   * @return `n` itself when no comprehension has a [[QueryParameter]] fetch clause, otherwise
   *         nested [[ParameterSwitch]] nodes (one level per affected fetch clause)
   */
  def transformServerSide(n: Node): Node = {
    // Pair each affected comprehension with its fetch parameter. The pattern already proves
    // the fetch is a QueryParameter, so the fold below needs neither a downcast nor a second
    // (refutable) pattern match.
    val affected = n.collect {
      case c @ Comprehension(_, _, _, _, _, _, _, _, Some(fetch: QueryParameter), _, _) => (c, fetch)
    }
    logger.debug(s"Affected fetch clauses in: ${affected.map(_._1).mkString(", ")}")
    affected.foldLeft(n) { case (tree, (comp, fetch)) =>
      // Copy of `tree` where every comprehension equal to `comp` gets `newFetch` as its fetch.
      def withFetch(newFetch: Node): Node =
        tree.replace(
          { case c2: Comprehension.Base if c2 == comp => c2.copy(fetch = Some(newFetch)) },
          keepType = true
        )
      val guarded = withFetch(LiteralNode(0L))
      val fallback = withFetch(QueryParameter(fetch.extractor, ScalaBaseType.longType))
      ParameterSwitch(ConstArray(compare(fetch.extractor, 0L) -> guarded), fallback).infer()
    }
  }

  /** Create a function that calls an extractor for a value and compares the result with a fixed value.
   *
   * The comparison uses universal `==` on `Any`. Boxed numbers of different types are
   * therefore compared numerically; for example, an `Int` 0 equals the `Long` 0L.
   *
   * @param f extractor applied to the parameter value
   * @param v the value the extracted result must equal
   * @return a predicate `param => v == f(param)` whose `toString` renders as `f(...) == v`
   */
  def compare(f: Any => Any, v: Any): Any => Boolean = new (Any => Boolean) {
    def apply(param: Any): Boolean = v == f(param)
    override def toString: String = s"$f(...) == $v"
  }
}