Skip to content

Commit

Permalink
Split MiscTest into RelationalMiscTest and JdbcMiscTest.
Browse files Browse the repository at this point in the history
- All misc. tests except the ones for "fake" nullability can run against
  RelationalProfile.
- The IfThen clauses in ConditionalExpr nodes were in reverse order for
  no good reason. Changing to declaration order.
- Support ConditionalExpr and Like in QueryInterpreter
  • Loading branch information
szeiger committed May 7, 2013
1 parent 688961a commit 67582b4
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 78 deletions.
@@ -0,0 +1,45 @@
package com.typesafe.slick.testkit.tests

import org.junit.Assert._
import com.typesafe.slick.testkit.util.{JdbcTestDB, TestkitTest}

class JdbcMiscTest extends TestkitTest[JdbcTestDB] {
import tdb.profile.simple._

override val reuseInstance = true

def testNullability {
object T1 extends Table[String]("t1") {
def a = column[String]("a")
def * = a
}

object T2 extends Table[String]("t2") {
def a = column[String]("a", O.Nullable)
def * = a
}

object T3 extends Table[Option[String]]("t3") {
def a = column[Option[String]]("a")
def * = a
}

object T4 extends Table[Option[String]]("t4") {
def a = column[Option[String]]("a", O.NotNull)
def * = a
}

(T1.ddl ++ T2.ddl ++ T3.ddl ++ T4.ddl).create

T1.insert("a")
T2.insert("a")
T3.insert(Some("a"))
T4.insert(Some("a"))

T2.insert(null.asInstanceOf[String])
T3.insert(None)

assertFail { T1.insert(null.asInstanceOf[String]) }
assertFail { T4.insert(None) }
}
}
@@ -1,9 +1,9 @@
package com.typesafe.slick.testkit.tests

import org.junit.Assert._
import com.typesafe.slick.testkit.util.{JdbcTestDB, TestkitTest}
import com.typesafe.slick.testkit.util.{RelationalTestDB, TestkitTest}

class MiscTest extends TestkitTest[JdbcTestDB] {
class RelationalMiscTest extends TestkitTest[RelationalTestDB] {
import tdb.profile.simple._

override val reuseInstance = true
Expand All @@ -16,66 +16,19 @@ class MiscTest extends TestkitTest[JdbcTestDB] {
}

T.ddl.create
T.insertAll(("1", "a"), ("2", "a"), ("3", "b"))
T ++= Seq(("1", "a"), ("2", "a"), ("3", "b"))

val q1 = for(t <- T if t.a === "1" || t.a === "2") yield t
println("q1: "+q1.selectStatement)
q1.foreach(println _)
assertEquals(q1.to[Set], Set(("1", "a"), ("2", "a")))
assertEquals(Set(("1", "a"), ("2", "a")), q1.run.toSet)

val q2 = for(t <- T if (t.a isNot "1") || (t.b isNot "a")) yield t
println("q2: "+q2.selectStatement)
q2.foreach(println _)
assertEquals(q2.to[Set], Set(("2", "a"), ("3", "b")))
assertEquals(Set(("2", "a"), ("3", "b")), q2.run.toSet)

// No need to test that the unexpected result is actually unexpected
// now that the compiler prints a warning about it
/*
val q3 = for(t <- T if (t.a != "1") || (t.b != "a")) yield t
println("q3: "+q3.selectStatement) // Hah, not what you expect!
q3.foreach(println _)
assertEquals(q3.to[Set], Set(("1", "a"), ("2", "a"), ("3", "b")))
*/

val q4 = for(t <- T if t.a =!= "1" || t.b =!= "a") yield t
println("q4: "+q4.selectStatement)
q4.foreach(println _)
assertEquals(q4.to[Set], Set(("2", "a"), ("3", "b")))
}

def testNullability {
object T1 extends Table[String]("t1") {
def a = column[String]("a")
def * = a
}

object T2 extends Table[String]("t2") {
def a = column[String]("a", O.Nullable)
def * = a
}

object T3 extends Table[Option[String]]("t3") {
def a = column[Option[String]]("a")
def * = a
}

object T4 extends Table[Option[String]]("t4") {
def a = column[Option[String]]("a", O.NotNull)
def * = a
}

(T1.ddl ++ T2.ddl ++ T3.ddl ++ T4.ddl).create

T1.insert("a")
T2.insert("a")
T3.insert(Some("a"))
T4.insert(Some("a"))

T2.insert(null.asInstanceOf[String])
T3.insert(None)

assertFail { T1.insert(null.asInstanceOf[String]) }
assertFail { T4.insert(None) }
assertEquals(Set(("2", "a"), ("3", "b")), q4.run.toSet)
}

def testLike {
Expand All @@ -85,20 +38,17 @@ class MiscTest extends TestkitTest[JdbcTestDB] {
}

T1.ddl.create
T1.insertAll("foo", "bar", "foobar", "foo%")
T1 ++= Seq("foo", "bar", "foobar", "foo%")

val q1 = for { t1 <- T1 if t1.a like "foo" } yield t1.a
println("q1: " + q1.selectStatement)
assertEquals(List("foo"), q1.list)
assertEquals(List("foo"), q1.run)

val q2 = for { t1 <- T1 if t1.a like "foo%" } yield t1.a
println("q2: " + q2.selectStatement)
assertEquals(Set("foo", "foobar", "foo%"), q2.to[Set])
assertEquals(Set("foo", "foobar", "foo%"), q2.run.toSet)

ifCap(rcap.likeEscape) {
val q3 = for { t1 <- T1 if t1.a.like("foo^%", '^') } yield t1.a
println("q3: " + q3.selectStatement)
assertEquals(Set("foo%"), q3.to[Set])
assertEquals(List("foo%"), q3.run)
}
}

Expand All @@ -111,7 +61,7 @@ class MiscTest extends TestkitTest[JdbcTestDB] {
}

T1.ddl.create
T1.insertAll(("a2", "b2", "c2"), ("a1", "b1", "c1"))
T1 ++= Seq(("a2", "b2", "c2"), ("a1", "b1", "c1"))

implicit class TupledQueryExtensionMethods[E1, E2, U1, U2](q: Query[(E1, E2), (U1, U2)]) {
def sortedValues(implicit ordered: (E1 => scala.slick.lifted.Ordered),
Expand All @@ -123,7 +73,7 @@ class MiscTest extends TestkitTest[JdbcTestDB] {
t1 <- T1
} yield t1.c -> (t1.a, t1.b)).sortedValues

assertEquals(List(("a1", "b1"), ("a2", "b2")), q1.list)
assertEquals(List(("a1", "b1"), ("a2", "b2")), q1.run)
}

def testConditional {
Expand All @@ -133,16 +83,16 @@ class MiscTest extends TestkitTest[JdbcTestDB] {
}

T1.ddl.create
T1.insertAll(1, 2, 3, 4)
T1 ++= Seq(1, 2, 3, 4)

val q1 = T1.map { t1 => (t1.a, Case.If(t1.a < 3) Then 1 Else 0) }
assertEquals(Set((1, 1), (2, 1), (3, 0), (4, 0)), q1.to[Set])
assertEquals(Set((1, 1), (2, 1), (3, 0), (4, 0)), q1.run.toSet)

val q2 = T1.map { t1 => (t1.a, Case.If(t1.a < 3) Then 1) }
assertEquals(Set((1, Some(1)), (2, Some(1)), (3, None), (4, None)), q2.to[Set])
assertEquals(Set((1, Some(1)), (2, Some(1)), (3, None), (4, None)), q2.run.toSet)

val q3 = T1.map { t1 => (t1.a, Case.If(t1.a < 3) Then 1 If(t1.a < 4) Then 2 Else 0) }
assertEquals(Set((1, 1), (2, 1), (3, 2), (4, 0)), q3.to[Set])
assertEquals(Set((1, 1), (2, 1), (3, 2), (4, 0)), q3.run.toSet)
}

def testCast {
Expand All @@ -153,10 +103,10 @@ class MiscTest extends TestkitTest[JdbcTestDB] {
}

T1.ddl.create
T1.insertAll(("foo", 1), ("bar", 2))
T1 ++= Seq(("foo", 1), ("bar", 2))

val q1 = T1.map(t1 => t1.a ++ t1.b.asColumnOf[String])
val r1 = q1.to[Set]
val r1 = q1.run.toSet
assertEquals(Set("foo1", "bar2"), r1)
}

Expand All @@ -168,14 +118,14 @@ class MiscTest extends TestkitTest[JdbcTestDB] {
}

T1.ddl.create
T1.insertAll((1, Some(10)), (2, None))
T1 ++= Seq((1, Some(10)), (2, None))

// GetOrElse in ResultSetMapping on client side
val q1 = for { t <- T1 } yield (t.a, t.b.getOrElse(0))
assertEquals(Set((1, 10), (2, 0)), q1.to[Set])
assertEquals(Set((1, 10), (2, 0)), q1.run.toSet)

// GetOrElse in query on the DB side
val q2 = for { t <- T1 } yield (t.a, t.b.getOrElse(0) + 1)
assertEquals(Set((1, 11), (2, 1)), q2.to[Set])
assertEquals(Set((1, 11), (2, 1)), q2.run.toSet)
}
}
Expand Up @@ -28,7 +28,8 @@ object Testkit {
classOf[tk.JoinTest] ::
classOf[tk.MainTest] ::
classOf[tk.MapperTest] ::
classOf[tk.MiscTest] ::
classOf[tk.RelationalMiscTest] ::
classOf[tk.JdbcMiscTest] ::
classOf[tk.MutateTest] ::
classOf[tk.NestingTest] ::
classOf[tk.NewQuerySemanticsTest] ::
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/scala/slick/driver/AccessDriver.scala
Expand Up @@ -131,7 +131,7 @@ trait AccessDriver extends JdbcDriver { driver =>
case c: ConditionalExpr => {
b"switch("
var first = true
c.clauses.reverseIterator.foreach { case IfThen(l, r) =>
c.clauses.foreach { case IfThen(l, r) =>
if(first) first = false
else b","
b"$l,$r"
Expand Down
Expand Up @@ -271,7 +271,7 @@ trait JdbcStatementBuilderComponent { driver: JdbcDriver =>
b")"
case c: ConditionalExpr =>
b"(case"
c.clauses.reverseIterator.foreach { case IfThen(l, r) => b" when $l then $r" }
c.clauses.foreach { case IfThen(l, r) => b" when $l then $r" }
c.elseClause match {
case LiteralNode(null) =>
case n => b" else $n"
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/scala/slick/lifted/Case.scala
Expand Up @@ -27,7 +27,7 @@ object Case {
}

final class TypedWhen[B : TypedType, T : TypedType](cond: Node, parentClauses: IndexedSeq[Node]) {
def Then(res: Column[T]) = new TypedCase[B,T](new IfThen(cond, Node(res)) +: parentClauses)
def Then(res: Column[T]) = new TypedCase[B,T](parentClauses :+ new IfThen(cond, Node(res)))
}

final class TypedCaseWithElse[T : TypedType](clauses: IndexedSeq[Node], elseClause: Node) extends Column[T] {
Expand Down
51 changes: 48 additions & 3 deletions src/main/scala/scala/slick/memory/QueryInterpreter.scala
Expand Up @@ -6,6 +6,7 @@ import scala.slick.ast._
import scala.slick.SlickException
import scala.slick.util.{SlickLogger, Logging}
import TypeUtil.typeToTypeUtil
import java.util.regex.Pattern

/** A query interpreter for the MemoryDriver and for client-side operations
* that need to be run as part of distributed queries against multiple
Expand Down Expand Up @@ -166,6 +167,19 @@ class QueryInterpreter(db: HeapBackend#Database) extends Logging {
run(ch).asInstanceOf[Option[Any]].getOrElse(default())
case OptionApply(ch) =>
Option(run(ch))
case ConditionalExpr(clauses, elseClause) =>
val opt = n.nodeType.asInstanceOf[ScalaType[_]].nullable
val take = clauses.find { case IfThen(pred, _) => asBoolean(run(pred)) }
take match {
case Some(IfThen(_, r)) =>
val res = run(r)
if(opt && !r.nodeType.asInstanceOf[ScalaType[_]].nullable) Option(res)
else res
case _ =>
val res = run(elseClause)
if(opt && !elseClause.nodeType.asInstanceOf[ScalaType[_]].nullable) Option(res)
else res
}
case Library.Sum(ch) =>
val coll = run(ch).asInstanceOf[Coll]
val (it, itType) = unwrapSingleColumn(coll, ch.nodeType)
Expand Down Expand Up @@ -205,9 +219,9 @@ class QueryInterpreter(db: HeapBackend#Database) extends Logging {
case other => other
}
logDebug("[chPlainV: "+chPlainV.mkString(", ")+"]")
Some(evalFunction(sym, chPlainV))
Some(evalFunction(sym, chPlainV, n.nodeType.asOptionType.elementType))
}
} else evalFunction(sym, chV)
} else evalFunction(sym, chV, n.nodeType)
//case Library.CountAll(ch) => run(ch).asInstanceOf[Coll].size
case l: LiteralNode => l.value
}
Expand All @@ -216,7 +230,7 @@ class QueryInterpreter(db: HeapBackend#Database) extends Logging {
res
}

def evalFunction(sym: Symbol, args: Seq[(Type, Any)]) = sym match {
def evalFunction(sym: Symbol, args: Seq[(Type, Any)], retType: Type) = sym match {
case Library.== => args(0)._2 == args(1)._2
case Library.< => args(0)._1.asInstanceOf[ScalaBaseType[Any]].ordering.lt(args(0)._2, args(1)._2)
case Library.<= => args(0)._1.asInstanceOf[ScalaBaseType[Any]].ordering.lteq(args(0)._2, args(1)._2)
Expand All @@ -226,6 +240,19 @@ class QueryInterpreter(db: HeapBackend#Database) extends Logging {
case Library.Or => args(0)._2.asInstanceOf[Boolean] || args(1)._2.asInstanceOf[Boolean]
case Library.CountAll => args(0)._2.asInstanceOf[Coll].size
case Library.+ => args(0)._1.asInstanceOf[ScalaNumericType[Any]].numeric.plus(args(0)._2, args(1)._2)
case Library.Cast =>
val v = args(0)._2
(args(0)._1, retType) match {
case (a, b) if a == b => v
case (_, ScalaType.stringType) => v.toString
case (_, ScalaType.intType) => v.toString.toInt
case (_, ScalaType.longType) => v.toString.toLong
}
case Library.Concat => args.iterator.map(_._2.toString).mkString
case Library.Like =>
val pat = compileLikePattern(args(1)._2.toString, if(args.length > 2) Some(args(2)._2.toString.charAt(0)) else None)
val mat = pat.matcher(args(0)._2.toString())
mat.matches()
}

def unwrapSingleColumn(coll: Coll, tpe: Type): (Iterator[Any], Type) = tpe.asCollectionType.elementType match {
Expand Down Expand Up @@ -256,6 +283,24 @@ class QueryInterpreter(db: HeapBackend#Database) extends Logging {
case None => false
case null => false
}

def compileLikePattern(s: String, escape: Option[Char]): Pattern = {
val b = new StringBuilder append '^'
val len = s.length
val esc = escape.getOrElse('\0')
var i = 0
while(i < len) {
s.charAt(i) match {
case e if e == esc =>
i += 1
b.append(Pattern.quote(String.valueOf(s.charAt(i))))
case '%' => b.append(".*")
case c => b.append(Pattern.quote(String.valueOf(c)))
}
i += 1
}
Pattern.compile(b.append('$').toString)
}
}

object QueryInterpreter {
Expand Down

0 comments on commit 67582b4

Please sign in to comment.