Skip to content
Permalink
Browse files

Fix lifting of aggregation functions which reference other generators.

Also use only new-style aggregations in the test cases. The old-style
aggregations will be removed (on master) and deprecated (on 1.0) in
separate commits.

Fixes issue #61.
(cherry picked from commit df40f1d)

Conflicts:

	slick-testkit/src/main/scala/com/typesafe/slick/testkit/tests/OldTest.scala
	src/main/scala/scala/slick/compiler/Relational.scala
	src/main/scala/scala/slick/lifted/Query.scala
  • Loading branch information
szeiger committed Nov 12, 2012
1 parent 9ee64a3 commit 9e97d5e6d964b0943db645d667903ca74a3cd692
@@ -87,12 +87,11 @@ class MainTest(val tdb: TestDB) extends TestkitTest {
println("All Orders by Users with a last name by first name:")
q3.foreach(o => println(" "+o))

val q4 = for (
u <- Users;
o <- Orders
if (o.orderID in (for { o2 <- Orders where(o.userID is _.userID) } yield o2.orderID.max))
&& (o.userID is u.id)
) yield u.first ~ o.orderID
val q4 = for {
u <- Users
o <- u.orders
if (o.orderID === (for { o2 <- Orders where(o.userID is _.userID) } yield o2.orderID).max)
} yield u.first ~ o.orderID
println("q4: " + q4.selectStatement)
println("Latest Order per User:")
q4.foreach(o => println(" "+o))
@@ -102,7 +101,7 @@ class MainTest(val tdb: TestDB) extends TestkitTest {

def maxOfPer[T <: Table[_]]
(c: T, m: (T => Column[Int]), p: (T => Column[Int])) =
c where { o => m(o) in (for { o2 <- c if p(o) is p(o2) } yield m(o2).max) }
c where { o => m(o) is (for { o2 <- c if p(o) is p(o2) } yield m(o2)).max }

val q4b = for (
u <- Users;
@@ -53,22 +53,22 @@ class OldTest(val tdb: TestDB) extends TestkitTest {
} yield u.first ~ o.orderID

val q5 = for (
o <- for (o <- Orders if o.orderID in (for {o2 <- Orders if o.userID is o2.userID} yield o2.orderID.max)) yield o.orderID;
o <- for (o <- Orders if o.orderID === (for {o2 <- Orders if o.userID is o2.userID} yield o2.orderID).max) yield o.orderID;
_ <- Query orderBy o
) yield o

val q6a = for (
o <- (for (o <- Orders if o.orderID in (for {o2 <- Orders if o.userID is o2.userID} yield o2.orderID.max)) yield o.orderID).sub;
o <- (for (o <- Orders if o.orderID === (for {o2 <- Orders if o.userID is o2.userID} yield o2.orderID).max) yield o.orderID);
_ <- Query orderBy o
) yield o

val q6b = for (
o <- (for (o <- Orders if o.orderID in (for {o2 <- Orders if o.userID is o2.userID} yield o2.orderID.max)) yield o.orderID ~ o.userID).sub;
o <- (for (o <- Orders if o.orderID === (for {o2 <- Orders if o.userID is o2.userID} yield o2.orderID).max) yield o.orderID ~ o.userID);
_ <- Query orderBy o._1
) yield o

val q6c = for (
o <- (for (o <- Orders if o.orderID in (for {o2 <- Orders if o.userID is o2.userID} yield o2.orderID.max)) yield o).sub;
o <- (for (o <- Orders if o.orderID === (for {o2 <- Orders if o.userID is o2.userID} yield o2.orderID).max) yield o);
_ <- Query orderBy o.orderID
) yield o.orderID ~ o.userID

@@ -47,7 +47,7 @@ object Benchmark {
} yield u.first ~ o.orderID
val q5 = for (
o <- Orders
where { o => o.orderID in (for { o2 <- Orders where(o.userID is _.userID) } yield o2.orderID.max) }
where { o => o.orderID === (for { o2 <- Orders where(o.userID is _.userID) } yield o2.orderID).max }
) yield o.orderID

val s1 = BasicDriver.buildSelectStatement(q1)
@@ -119,4 +119,10 @@ object ExtraUtil {
case r: RowNumber => f(r)
case n => n.nodeMapChildren(ch => replaceRowNumber(ch)(f))
}

def hasRefToOneOf(n: Node, s: scala.collection.Set[Symbol]): Boolean = n match {
case r: RefNode =>
r.nodeReferences.exists(sym => s.contains(sym)) || n.nodeChildren.exists(ch => hasRefToOneOf(ch, s))
case n => n.nodeChildren.exists(ch => hasRefToOneOf(ch, s))
}
}
@@ -6,6 +6,7 @@ import scala.slick.SlickException
import scala.slick.lifted.ConstColumn
import scala.slick.ast._
import Util._
import ExtraUtil._

/** Rewrite zip joins into a form suitable for SQL (using inner joins and
* RowNumber columns.
@@ -139,13 +140,12 @@ class FuseComprehensions extends Phase {

def fuse(n: Node): Node = n.nodeMapChildren(fuse) match {
case c: Comprehension =>
logger.debug("Checking:",c)
val fused = createSelect(c) match {
case c2: Comprehension if isFuseableOuter(c2) => fuseComprehension(c2)
case c2 => c2
}
val f2 = liftAggregates(fused)
//if(f2 eq fused) f2 else fuse(f2)
f2
liftAggregates(fused)
case n => n
}

@@ -237,39 +237,61 @@ class FuseComprehensions extends Phase {
else c
}

/** Lift aggregates of sub-queries into the 'from' list. */
/** Lift aggregates of sub-queries into the 'from' list or inline them
* (if they would refer to unreachable symbols when used in 'from'
* position). */
def liftAggregates(c: Comprehension): Comprehension = {
val lift = ArrayBuffer[(AnonSymbol, AnonSymbol, Library.AggregateFunctionSymbol, Comprehension)]()
val seenGens = HashMap[Symbol, Node]()
def tr(n: Node): Node = n match {
//TODO Once we can recognize structurally equivalent sub-queries and merge them, c2 could be a Ref
case Apply(s: Library.AggregateFunctionSymbol, Seq(c2: Comprehension)) =>
val a = new AnonSymbol
val f = new AnonSymbol
lift += ((a, f, s, c2))
Select(Ref(a), f)
case c: Comprehension => c // don't recurse into sub-queries
case n => n.nodeMapChildren(tr)
}
if(c.select.isEmpty) c else {
val sel = c.select.get
val sel2 = tr(sel)
if(lift.isEmpty) c else {
val newFrom = lift.map { case (a, f, s, c2) =>
val a2 = new AnonSymbol
val (c2b, call) = s match {
case ap @ Apply(s: Library.AggregateFunctionSymbol, Seq(c2: Comprehension)) =>
if(hasRefToOneOf(c2, seenGens.keySet)) {
logger.debug("Seen reference to one of {"+seenGens.keys.mkString(", ")+"} in "+c2+" -- inlining")
// This could still produce illegal SQL code if the reference is nested within another
// sub-query somewhere in 'from' position. Not much we can do about this though.
s match {
case Library.CountAll =>
(c2, Library.Count(ConstColumn(1)))
c2.copy(select = Some(Pure(ProductNode(Seq(Library.Count(ConstColumn(1)))))))
case s =>
val c3 = ensureStruct(c2)
// All standard aggregate functions operate on a single column
val Some(Pure(StructNode(Seq((f2, _))))) = c3.select
(c3, Apply(s, Seq(Select(Ref(a2), f2))))
val Some(Pure(StructNode(Seq((f2, expr))))) = c3.select
c3.copy(select = Some(Pure(ProductNode(Seq(Apply(s, Seq(expr)))))))
}
a -> Comprehension(from = Seq(a2 -> c2b),
select = Some(Pure(StructNode(IndexedSeq(f -> call)))))
} else {
val a = new AnonSymbol
val f = new AnonSymbol
lift += ((a, f, s, c2))
Select(Ref(a), f)
}
case c: Comprehension => c // don't recurse into sub-queries
case n => n.nodeMapChildren(tr)
}
val c2 = c.nodeMapScopedChildren {
case (Some(gen), ch) =>
seenGens += gen -> ch
ch
case (None, ch) => tr(ch)
}.asInstanceOf[Comprehension]
if(lift.isEmpty) c2
else {
val newFrom = lift.map { case (a, f, s, c2) =>
val a2 = new AnonSymbol
val (c2b, call) = s match {
case Library.CountAll =>
(c2, Library.Count(ConstColumn(1)))
case s =>
val c3 = ensureStruct(c2)
// All standard aggregate functions operate on a single column
val Some(Pure(StructNode(Seq((f2, _))))) = c3.select
(c3, Apply(s, Seq(Select(Ref(a2), f2))))
}
c.copy(from = c.from ++ newFrom, select = Some(sel2))
a -> Comprehension(from = Seq(a2 -> c2b),
select = Some(Pure(StructNode(IndexedSeq(f -> call)))))
}
logger.debug("Introducing new generator(s) "+newFrom.map(_._1).mkString(", ")+" for aggregations")
c2.copy(from = c.from ++ newFrom)
}
}

@@ -86,6 +86,7 @@ abstract class Query[+E, U] extends Rep[Seq[U]] with CollectionLinearizer[Seq, U
def length: Column[Int] = Library.CountAll.column(Node(unpackable.value))
@deprecated("Use .length instead of .count", "0.10.0-M2")
def count = length
def countDistinct: Column[Int] = Library.CountDistinct.column(Node(unpackable.value))
def exists = Library.Exists.column[Boolean](Node(unpackable.value))

@deprecated("Query.sub is not needed anymore", "0.10.0-M2")

0 comments on commit 9e97d5e

Please sign in to comment.
You can’t perform that action at this time.