Permalink
Browse files

Merge pull request #142 from slick/tmp/aggregate-explicit-join

Proper rewriting of aggregation functions over explicit joins.
  • Loading branch information...
cvogt committed Apr 29, 2013
2 parents ecb56e8 + 46bd568 commit ac50017d1d3851321d08192f68301a853c9a9377
@@ -118,5 +118,15 @@ class AggregateTest(val tdb: TestDB) extends TestkitTest {
assertEquals(List( ((1,Some(1)),1), ((1,Some(2)),1), ((1,Some(3)),1),
((2,Some(1)),1), ((2,Some(2)),1), ((2,Some(5)),1),
((3,Some(1)),1), ((3,Some(9)),1)), r4)
U.insert(4)
println("=========================================================== q5")
val q5 = (for {
(u, t) <- U leftJoin T on (_.id === _.a)
} yield (u, t)).groupBy(_._1.id).map {
case (id, q) => (id, q.length, q.map(_._1).length, q.map(_._2).length)
}
assertEquals(Set((1, 3, 3, 3), (2, 3, 3, 3), (3, 2, 2, 2), (4, 1, 1, 0)), q5.run.toSet)
}
}
@@ -102,14 +102,24 @@ class ConvertToComprehensions extends Phase {
/** Convert a GroupBy followed by an aggregating map operation to a Comprehension */
def convertSimpleGrouping(gen: Symbol, fromGen: Symbol, from: Node, by: Node, sel: Node): Node = {
object FwdPath {
def apply(ch: List[Symbol]) = Path(ch.reverse)
def unapply(n: Node): Option[List[Symbol]] =
Path.unapply(n).map(_.reverse)
def toString(path: Seq[Symbol]): String = path.mkString("Path ", ".", "")
}
object ProductOfCommonPaths {
def unapply(n: ProductNode): Option[(Symbol, Vector[List[Symbol]])] = if(n.nodeChildren.isEmpty) None else
n.nodeChildren.foldLeft(null: Option[(Symbol, Vector[List[Symbol]])]) {
case (None, _) => None
case (null, FwdPath(sym :: rest)) => Some((sym, Vector(rest)))
case (Some((sym0, v)), FwdPath(sym :: rest)) if sym == sym0 => Some((sym, v :+ rest))
case _ => None
}
}
val newBy = by.replace { case Ref(f) if f == fromGen => Ref(gen) }
val newSel = sel.replace {
case Bind(s1, Select(Ref(gen2), ElementSymbol(2)), Pure(ProductNode(Seq(Select(Ref(s2), field)))))
if (s2 == s1) && (gen2 == gen) => Select(Ref(gen), field)
case Apply(fs, Seq(Bind(s1, Select(Ref(gen2), ElementSymbol(2)), Pure(ProductOfCommonPaths(s2, rests)))))
if (s2 == s1) && (gen2 == gen) =>
Apply(if(fs == Library.CountAll) Library.Count else fs, Seq(FwdPath(gen :: rests.head)))
case Library.CountAll(Select(Ref(gen2), ElementSymbol(2))) if gen2 == gen =>
Library.Count(ConstColumn(1))
case FwdPath(gen2 :: ElementSymbol(idx) :: rest) if gen2 == gen && (idx == 1 || idx == 2) =>

0 comments on commit ac50017

Please sign in to comment.