Permalink
Browse files

Ensure non-empty “select” clauses in mergeToComprehensions:

Am empty select is filled with a dummy `LiteralNode(1)` to ensure
proper translation of such a Comprehension in `removeFieldNames` and
the code generator.

We also add an option to `RemoveFieldNames` for keeping StructNodes
intact when used in a “from” clause of a Comprehension, even when there
are no references. This can be used to force the creation of aliases
for dummy nodes in the code generator when required by a database.
  • Loading branch information...
szeiger committed Sep 9, 2015
1 parent f82a8fe commit ebfb344de4c899163c87cf971dfdf3e9fc656bbc
@@ -15,13 +15,13 @@ class CountTest extends AsyncTest[RelationalTestDB] {
_ <- testTable.schema.create
_ <- testTable ++= Seq(1, 2, 3, 4, 5)
q1 = Query(testTable.length)
_ <- q1.result.map(_ shouldBe Vector(5))
_ <- mark("q1", q1.result).map(_ shouldBe Vector(5))
q2 = testTable.length
_ <- q2.result.map(_ shouldBe 5)
_ <- mark("q2", q2.result).map(_ shouldBe 5)
q3 = testTable.filter(_.id < 3).length
_ <- q3.result.map(_ shouldBe 2)
_ <- mark("q3", q3.result).map(_ shouldBe 2)
q4 = testTable.take(2).length
_ <- q4.result.map(_ shouldBe 2)
_ <- mark("q4", q4.result).map(_ shouldBe 2)
} yield ()
}
@@ -232,7 +232,13 @@ class MergeToComprehensions extends Phase {
val (c, rep) = mergeTakeDrop(n, false)
val mappings = ConstArray.from(rep.mapValues(_ :: Nil))
logger.debug("Mappings are: "+mappings)
(c, mappings)
val c2 = c.select match {
// Ensure that the select clause is non-empty
case Pure(StructNode(ConstArray.empty), _) =>
c.copy(select = Pure(StructNode(ConstArray((new AnonSymbol, LiteralNode(1)))))).infer()
case _ => c
}
(c2, mappings)
}
def convert1(n: Node): Node = n match {
@@ -3,10 +3,11 @@ package slick.compiler
import slick.ast._
import Util._
import TypeUtil._
import slick.util.ConstArray
/** Convert unreferenced StructNodes to single columns or ProductNodes (which is needed for
* aggregation functions and at the top level). */
class RemoveFieldNames extends Phase {
class RemoveFieldNames(val alwaysKeepSubqueryNames: Boolean = false) extends Phase {
val name = "removeFieldNames"
def apply(state: CompilerState) = state.map { n => ClientSideOp.mapResultSetMapping(n, true) { rsm =>
@@ -16,10 +17,14 @@ class RemoveFieldNames extends Phase {
val refTSyms = n.collect[TypeSymbol] {
case Select(_ :@ NominalType(s, _), _) => s
case Union(_, _ :@ CollectionType(_, NominalType(s, _)), _) => s
case Comprehension(_, _ :@ CollectionType(_, NominalType(s, _)), _, _, _, _, _, _, _, _) if alwaysKeepSubqueryNames => s
}.toSet
val allTSyms = n.collect[TypeSymbol] { case p: Pure => p.identity }.toSet
val unrefTSyms = allTSyms -- refTSyms
n.replaceInvalidate {
case Pure(StructNode(ConstArray.empty), pts) =>
// Always convert an empty StructNode because there is nothing to reference
(Pure(ProductNode(ConstArray.empty), pts), pts)
case Pure(StructNode(ch), pts) if unrefTSyms contains pts =>
(Pure(if(ch.length == 1 && pts != top) ch(0)._2 else ProductNode(ch.map(_._2)), pts), pts)
}.infer()

0 comments on commit ebfb344

Please sign in to comment.