Permalink
Browse files

Use correct discriminator checks under three-valued logic

`expandSums` used to generate discriminator checks of the form
`disc == 1` which do not compose correctly when the discriminator is
null, so we now use `disc is not null` instead.

Fixes #1156. Test in JoinTest.testDiscriminatorCheck.
  • Loading branch information...
szeiger committed Aug 27, 2015
1 parent 723f481 commit 969ec5f23c367e76c1f15725092b0c28b1297d4b
@@ -306,4 +306,29 @@ class JoinTest extends AsyncTest[RelationalTestDB] {
q3.result.named("q3").map(_.toSet shouldBe Set((1,1), (2,2)))
)
}
def testDiscriminatorCheck = {
class A(tag: Tag) extends Table[Int](tag, "a_joinfiltering") {
def id = column[Int]("id")
def * = id
}
lazy val as = TableQuery[A]
class B(tag: Tag) extends Table[Option[Int]](tag, "b_joinfiltering") {
def id = column[Option[Int]]("id")
def * = id
}
lazy val bs = TableQuery[B]
val q = for {
(a, b) <- as joinLeft bs on (_.id.? === _.id) if (b.isEmpty)
} yield (a.id)
DBIO.seq(
(as.schema ++ bs.schema).create,
as ++= Seq(1,2,3),
bs ++= Seq(1,2,4,5).map(Some.apply _),
q.result.map(_.toSet shouldBe Set(3))
)
}
}
@@ -53,16 +53,16 @@ class ExpandSums extends Phase {
// Other OptionFold -> translate to discriminator check
case OptionFold(from, ifEmpty, map, gen) =>
val left = from.select(ElementSymbol(1)).infer()
val pred = Library.==.typed[Boolean](left, Disc1)
val pred = Library.==.typed[Boolean](left, LiteralNode(null))
val n2 = (ifEmpty, map) match {
case (LiteralNode(true), LiteralNode(false)) => Library.Not.typed[Boolean](pred)
case (LiteralNode(false), LiteralNode(true)) => pred
case (LiteralNode(true), LiteralNode(false)) => pred
case (LiteralNode(false), LiteralNode(true)) => Library.Not.typed[Boolean](pred)
case _ =>
val ifDefined = map.replace({
case r @ Ref(s) if s == gen => silentCast(r.nodeType, from.select(ElementSymbol(2)).infer())
}, keepType = true)
val ifEmpty2 = silentCast(ifDefined.nodeType.structural, ifEmpty)
if(left == Disc1) ifDefined else IfThenElse(ConstArray(pred, ifDefined, ifEmpty2))
if(left == Disc1) ifDefined else IfThenElse(ConstArray(Library.Not.typed[Boolean](pred), ifDefined, ifEmpty2))
}
n2.infer()
@@ -224,7 +224,7 @@ class ExpandSums extends Phase {
/** Fuse unnecessary Option operations */
def fuse(n: Node): Node = n match {
// Option.map
case IfThenElse(ConstArray(Library.==(disc, Disc1), ProductNode(ConstArray(Disc1, map)), ProductNode(ConstArray(DiscNone, _)))) =>
case IfThenElse(ConstArray(Library.Not(Library.==(disc, LiteralNode(null))), ProductNode(ConstArray(Disc1, map)), ProductNode(ConstArray(DiscNone, _)))) =>
ProductNode(ConstArray(disc, map)).infer()
case n => n
}

0 comments on commit 969ec5f

Please sign in to comment.