Permalink
Browse files

Allow scalar values at the top level in the new code generator.

ScalarFunctionTest now compiles and runs successfully with H2.
  • Loading branch information...
1 parent 6218f0f commit 92ff514606d13eb30e65541e46ebfebdb18e1562 @szeiger committed Mar 16, 2012
@@ -72,37 +72,37 @@ object SimpleBinaryOperator {
case class SimpleLiteral(name: String) extends NullaryNode
trait SimpleExpression extends SimpleNode {
- def toSQL(b: SQLBuilder, qb: BasicQueryBuilder): Unit
+ def toSQL(qb: BasicQueryBuilder): Unit
}
object SimpleExpression {
- def apply[T : TypeMapper](f: (Seq[Node], SQLBuilder, BasicQueryBuilder) => Unit): (Seq[Column[_]] => OperatorColumn[T] with SimpleExpression) = {
+ def apply[T : TypeMapper](f: (Seq[Node], BasicQueryBuilder) => Unit): (Seq[Column[_]] => OperatorColumn[T] with SimpleExpression) = {
lazy val builder: (Seq[NodeGenerator] => OperatorColumn[T] with SimpleExpression) = paramsC =>
new OperatorColumn[T] with SimpleExpression {
- def toSQL(b: SQLBuilder, qb: BasicQueryBuilder) = f(nodeChildren.toSeq, b, qb)
+ def toSQL(qb: BasicQueryBuilder) = f(nodeChildren.toSeq, qb)
protected[this] def nodeChildGenerators = paramsC
protected[this] def nodeRebuild(ch: IndexedSeq[Node]): Node = builder(ch)
}
builder
}
- def nullary[R : TypeMapper](f: (SQLBuilder, BasicQueryBuilder) => Unit): OperatorColumn[R] with SimpleExpression = {
- val g = apply({ (ch: Seq[Node], b: SQLBuilder, qb: BasicQueryBuilder) => f(b, qb) });
+ def nullary[R : TypeMapper](f: BasicQueryBuilder => Unit): OperatorColumn[R] with SimpleExpression = {
+ val g = apply({ (ch: Seq[Node], qb: BasicQueryBuilder) => f(qb) });
g.apply(Seq())
}
- def unary[T1, R : TypeMapper](f: (Node, SQLBuilder, BasicQueryBuilder) => Unit): (Column[T1] => OperatorColumn[R] with SimpleExpression) = {
- val g = apply({ (ch: Seq[Node], b: SQLBuilder, qb: BasicQueryBuilder) => f(ch(0), b, qb) });
+ def unary[T1, R : TypeMapper](f: (Node, BasicQueryBuilder) => Unit): (Column[T1] => OperatorColumn[R] with SimpleExpression) = {
+ val g = apply({ (ch: Seq[Node], qb: BasicQueryBuilder) => f(ch(0), qb) });
{ t1: Column[T1] => g(Seq(t1)) }
}
- def binary[T1, T2, R : TypeMapper](f: (Node, Node, SQLBuilder, BasicQueryBuilder) => Unit): ((Column[T1], Column[T2]) => OperatorColumn[R] with SimpleExpression) = {
- val g = apply({ (ch: Seq[Node], b: SQLBuilder, qb: BasicQueryBuilder) => f(ch(0), ch(1), b, qb) });
+ def binary[T1, T2, R : TypeMapper](f: (Node, Node, BasicQueryBuilder) => Unit): ((Column[T1], Column[T2]) => OperatorColumn[R] with SimpleExpression) = {
+ val g = apply({ (ch: Seq[Node], qb: BasicQueryBuilder) => f(ch(0), ch(1), qb) });
{ (t1: Column[T1], t2: Column[T2]) => g(Seq(t1, t2)) }
}
- def ternary[T1, T2, T3, R : TypeMapper](f: (Node, Node, Node, SQLBuilder, BasicQueryBuilder) => Unit): ((Column[T1], Column[T2], Column[T3]) => OperatorColumn[R] with SimpleExpression) = {
- val g = apply({ (ch: Seq[Node], b: SQLBuilder, qb: BasicQueryBuilder) => f(ch(0), ch(1), ch(2), b, qb) });
+ def ternary[T1, T2, T3, R : TypeMapper](f: (Node, Node, Node, BasicQueryBuilder) => Unit): ((Column[T1], Column[T2], Column[T3]) => OperatorColumn[R] with SimpleExpression) = {
+ val g = apply({ (ch: Seq[Node], qb: BasicQueryBuilder) => f(ch(0), ch(1), ch(2), qb) });
{ (t1: Column[T1], t2: Column[T2], t3: Column[T3]) => g(Seq(t1, t2, t3)) }
}
}
@@ -36,7 +36,7 @@ class BasicQueryBuilder(_query: Query[_, _], _profile: BasicProfile) {
if(from.length <= 1) b += "*"
else b += symbolName(from.last._1) += ".*"
}
- if(from.isEmpty) scalarFrom.foreach { s => b += " from " += s }
+ if(from.isEmpty) buildScalarFrom
else {
b += " from "
b.sep(from, ", ") { case (sym, n) =>
@@ -51,6 +51,10 @@ class BasicQueryBuilder(_query: Query[_, _], _profile: BasicProfile) {
case Pure(CountAll(q)) =>
b += "select count(*) from "
buildFrom(q, None)
+ case p @ Pure(_) =>
+ b += "select "
+ buildSelectClause(p)
+ buildScalarFrom
case AbstractTable(name) =>
b += "select * from " += quoteIdentifier(name)
case TakeDrop(from, take, drop) => buildTakeDrop(from, take, drop)
@@ -62,6 +66,8 @@ class BasicQueryBuilder(_query: Query[_, _], _profile: BasicProfile) {
case n => throw new SQueryException("Unexpected node "+n+" -- SQL prefix: "+b.build.sql)
}
+ protected def buildScalarFrom: Unit = scalarFrom.foreach { s => b += " from " += s }
+
protected def buildTakeDrop(from: Node, take: Option[Int], drop: Option[Int]) {
if(take == Some(0)) {
b += "select * from "
@@ -121,7 +127,7 @@ class BasicQueryBuilder(_query: Query[_, _], _profile: BasicProfile) {
case s => quoteIdentifier(s.name)
}
- protected def expr(n: Node): Unit = n match {
+ def expr(n: Node): Unit = n match {
case ConstColumn(null) => b += "null"
case Not(Is(l, ConstColumn(null))) => b += '('; expr(l); b += " is not null)"
case Not(e) => b += "(not "; expr(e); b+= ')'
@@ -142,7 +148,7 @@ class BasicQueryBuilder(_query: Query[_, _], _profile: BasicProfile) {
b += ')'
if(s.scalar) b += '}'
case SimpleLiteral(w) => b += w
- case s: SimpleExpression => s.toSQL(b, this)
+ case s: SimpleExpression => s.toSQL(this)
case Between(left, start, end) => expr(left); b += " between "; expr(start); b += " and "; expr(end)
case CountDistinct(e) => b += "count(distinct "; expr(e); b += ')'
case Like(l, r, esc) =>
@@ -26,7 +26,7 @@ class H2QueryBuilder(_query: Query[_, _], profile: H2Driver) extends BasicQueryB
override protected val mayLimit0 = false
override protected val concatOperator = Some("||")
- override protected def expr(n: Node) = n match {
+ override def expr(n: Node) = n match {
case Sequence.Nextval(seq) => b += "nextval(schema(), '" += seq.name += "')"
case Sequence.Currval(seq) => b += "currval(schema(), '" += seq.name += "')"
case _ => super.expr(n)
@@ -11,7 +11,6 @@ import org.scalaquery.session._
import org.scalaquery.session.Database.threadLocalSession
import org.scalaquery.test.util._
import org.scalaquery.test.util.TestDB._
-import org.scalaquery.util.{SQLBuilder, BinaryNode, Node}
import java.sql.{Time, Date, Timestamp}
object ScalarFunctionTest extends DBTestObject(H2Mem, SQLiteMem, Postgres, MySQL, DerbyMem, HsqldbMem, MSAccess, SQLServer)
@@ -76,12 +75,12 @@ class ScalarFunctionTest(tdb: TestDB) extends DBTest(tdb) {
check(Query(Functions.pi.toDegrees), 180.0)
check(Query(Functions.pi.toDegrees.toRadians is Functions.pi), true)
- val myExpr = SimpleExpression.binary[Int, Int, Int] { (l, r, b, qb) =>
- b += '('
- qb.expr(l, b)
- b += '+'
- qb.expr(r, b)
- b += "+1)"
+ val myExpr = SimpleExpression.binary[Int, Int, Int] { (l, r, qb) =>
+ qb.sqlBuilder += '('
+ qb.expr(l)
+ qb.sqlBuilder += '+'
+ qb.expr(r)
+ qb.sqlBuilder += "+1)"
}
check(Query(myExpr(4, 5)), 10)

0 comments on commit 92ff514

Please sign in to comment.