Permalink
Browse files

Improve code generation

  • Loading branch information...
szeiger committed Aug 4, 2015
1 parent aaf2a22 commit c762b55c2d900c51db318c0710435d180085327a
@@ -5,7 +5,13 @@ package slick.ast
*/
object Library {
trait AggregateFunctionSymbol extends TermSymbol
class JdbcFunction(name: String) extends FunctionSymbol(name)
class JdbcFunction(name: String) extends FunctionSymbol(name) {
override def hashCode = name.hashCode
override def equals(o: Any) = o match {
case o: JdbcFunction => name == o.name
case _ => false
}
}
class SqlFunction(name: String) extends FunctionSymbol(name)
class SqlOperator(name: String) extends FunctionSymbol(name)
class AggregateFunction(name: String) extends FunctionSymbol(name) with AggregateFunctionSymbol
@@ -89,15 +89,15 @@ class HoistClientOps extends Phase {
}
/** Rewrite remaining `GetOrElse` operations in the server-side tree into conditionals. */
def rewriteDBSide(tree: Node): Node = tree match {
case GetOrElse(ch, default) =>
def rewriteDBSide(tree: Node): Node = tree.replace({
case GetOrElse(OptionApply(ch), _) => ch
case n @ GetOrElse(ch :@ OptionType(tpe), default) =>
logger.debug("Translating GetOrElse to IfNull", n)
val d = try default() catch {
case NonFatal(ex) => throw new SlickException(
"Caught exception while computing default value for Rep[Option[_]].getOrElse -- "+
"This cannot be done lazily when the value is needed on the database side", ex)
}
val ch2 :@ OptionType(tpe) = rewriteDBSide(ch)
Library.IfNull.typed(tpe, ch2, LiteralNode.apply(tpe, d)).infer()
case n => n.mapChildren(rewriteDBSide, keepType = true)
}
Library.IfNull.typed(tpe, ch, LiteralNode(tpe, d)).infer()
}, keepType = true, bottomUp = true)
}
@@ -118,8 +118,10 @@ trait DerbyDriver extends JdbcDriver { driver =>
override val scalarFrom = Some("sysibm.sysdummy1")
class QueryBuilder(tree: Node, state: CompilerState) extends super.QueryBuilder(tree, state) {
override protected val concatOperator = Some("||")
override protected val supportsTuples = false
override protected val supportsLiteralGroupBy = true
override protected val quotedJdbcFns = Some(Vector(Library.User))
override def expr(c: Node, skipParens: Boolean = false): Unit = c match {
case Library.Cast(ch @ _*) =>
@@ -83,6 +83,7 @@ trait H2Driver extends JdbcDriver { driver =>
override protected val concatOperator = Some("||")
override protected val alwaysAliasSubqueries = false
override protected val supportsLiteralGroupBy = true
override protected val quotedJdbcFns = Some(Nil)
override def expr(n: Node, skipParens: Boolean = false) = n match {
case Library.NextValue(SequenceNode(name)) => b"nextval(schema(), '$name')"
@@ -66,6 +66,7 @@ trait HsqldbDriver extends JdbcDriver { driver =>
override protected val concatOperator = Some("||")
override protected val alwaysAliasSubqueries = false
override protected val supportsLiteralGroupBy = true
override protected val quotedJdbcFns = Some(Nil)
override def expr(c: Node, skipParens: Boolean = false): Unit = c match {
case l @ LiteralNode(v: String) if (v ne null) && jdbcTypeFor(l.nodeType).sqlType != Types.CHAR =>
@@ -103,6 +103,7 @@ trait JdbcStatementBuilderComponent { driver: JdbcDriver =>
protected val pi = "3.1415926535897932384626433832795"
protected val alwaysAliasSubqueries = true
protected val supportsLiteralGroupBy = false
protected val quotedJdbcFns: Option[Seq[Library.JdbcFunction]] = None // quote all by default
// Mutable state accessible to subclasses
protected val b = new SQLBuilder
@@ -371,9 +372,12 @@ trait JdbcStatementBuilderComponent { driver: JdbcDriver =>
} else b.sep(ch, " " + sym.name + " ")(expr(_))
b"\)"
case Apply(sym: Library.JdbcFunction, ch) =>
b"{fn ${sym.name}("
val quote = quotedJdbcFns.map(_.contains(sym)).getOrElse(true)
if(quote) b"{fn "
b"${sym.name}("
b.sep(ch, ",")(expr(_, true))
b")}"
b")"
if(quote) b"}"
case Apply(sym: Library.SqlFunction, ch) =>
b"${sym.name}("
b.sep(ch, ",")(expr(_, true))
@@ -137,6 +137,7 @@ trait MySQLDriver extends JdbcDriver { driver =>
class QueryBuilder(tree: Node, state: CompilerState) extends super.QueryBuilder(tree, state) {
override protected val supportsCast = false
override protected val parenthesizeNestedRHSJoin = true
override protected val quotedJdbcFns = Some(Nil)
override def expr(n: Node, skipParens: Boolean = false): Unit = n match {
case Library.Cast(ch) :@ JdbcType(ti, _) =>
@@ -125,6 +125,7 @@ trait PostgresDriver extends JdbcDriver { driver =>
class QueryBuilder(tree: Node, state: CompilerState) extends super.QueryBuilder(tree, state) {
override protected val concatOperator = Some("||")
override protected val quotedJdbcFns = Some(Vector(Library.Database, Library.User))
override protected def buildFetchOffsetClause(fetch: Option[Node], offset: Option[Node]) = (fetch, offset) match {
case (Some(t), Some(d)) => b"\nlimit $t offset $d"
@@ -134,6 +135,9 @@ trait PostgresDriver extends JdbcDriver { driver =>
}
override def expr(n: Node, skipParens: Boolean = false) = n match {
case Library.UCase(ch) => b"upper($ch)"
case Library.LCase(ch) => b"lower($ch)"
case Library.IfNull(ch, d) => b"coalesce($ch, $d)"
case Library.NextValue(SequenceNode(name)) => b"nextval('$name')"
case Library.CurrentValue(SequenceNode(name)) => b"currval('$name')"
case _ => super.expr(n, skipParens)
@@ -137,6 +137,7 @@ trait SQLiteDriver extends JdbcDriver { driver =>
override protected val concatOperator = Some("||")
override protected val parenthesizeNestedRHSJoin = true
override protected val alwaysAliasSubqueries = false
override protected val quotedJdbcFns = Some(Nil)
override protected def buildOrdering(n: Node, o: Ordering) {
if(o.nulls.last && !o.direction.desc)
@@ -167,18 +168,6 @@ trait SQLiteDriver extends JdbcDriver { driver =>
case Library.Floor(ch) => b"round($ch-0.5)"
case Library.User() => b"''"
case Library.Database() => b"''"
case Apply(j: Library.JdbcFunction, ch) if j != Library.Concat =>
/* The SQLite JDBC driver does not support ODBC {fn ...} escapes, so we try
* unescaped function calls by default */
b"${j.name}("
b.sep(ch, ",")(expr(_, true))
b")"
case s: SimpleFunction if s.scalar =>
/* The SQLite JDBC driver does not support ODBC {fn ...} escapes, so we try
* unescaped function calls by default */
b"${s.name}("
b.sep(s.children, ",")(expr(_, true))
b")"
case RowNumber(_) => throw new SlickException("SQLite does not support row numbers")
// https://github.com/jOOQ/jOOQ/issues/1595
case Library.Repeat(n, times) => b"replace(substr(quote(zeroblob(($times + 1) / 2)), 3, $times), '0', $n)"

0 comments on commit c762b55

Please sign in to comment.