Permalink
Browse files

Proper rewriting of aggregation functions over explicit joins.

Aggregations over the left or right side of an explicit join were not
properly translated in ConvertToComprehensions.convertSimpleGrouping.

We still have a design problem here: In SQL, a count over a single column
disregards NULL values whereas a count(*) counts all rows of a view. This
distinction should not exist in collection semantics (and it cannot be
implemented in ConvertToComprehensions, long after TableRefExpansions have
been eliminated, since there is no way of telling a count of a single
column from a count of a view of which we only use a single column later).

For now, the fix is to count the first column of the projection. This is
in line with the behaviour in 1.0.0, allows the correct handling of
single-column projections and produces the right result in most cases of
counting joined tables.

We should improve this further when we add proper support for nested
collections.

Fixes issue #135. Test in AggregateTest.testGroupBy.
  • Loading branch information...
szeiger committed Apr 23, 2013
1 parent fa5c251 commit 46bd56851db69530c09d48c2fd232378fef9939a
@@ -118,5 +118,15 @@ class AggregateTest(val tdb: TestDB) extends TestkitTest {
assertEquals(List( ((1,Some(1)),1), ((1,Some(2)),1), ((1,Some(3)),1),
((2,Some(1)),1), ((2,Some(2)),1), ((2,Some(5)),1),
((3,Some(1)),1), ((3,Some(9)),1)), r4)
U.insert(4)
println("=========================================================== q5")
val q5 = (for {
(u, t) <- U leftJoin T on (_.id === _.a)
} yield (u, t)).groupBy(_._1.id).map {
case (id, q) => (id, q.length, q.map(_._1).length, q.map(_._2).length)
}
assertEquals(Set((1, 3, 3, 3), (2, 3, 3, 3), (3, 2, 2, 2), (4, 1, 1, 0)), q5.run.toSet)
}
}
@@ -102,14 +102,24 @@ class ConvertToComprehensions extends Phase {
/** Convert a GroupBy followed by an aggregating map operation to a Comprehension */
def convertSimpleGrouping(gen: Symbol, fromGen: Symbol, from: Node, by: Node, sel: Node): Node = {
/** Helper for building and matching `Path` nodes with the symbol list in the
  * opposite order to the one `Path` itself takes and returns. */
object FwdPath {
  /** Build a `Path` from `ch`, reversing the list before handing it to `Path`. */
  def apply(ch: List[Symbol]) = Path(ch.reverse)

  /** Match whatever `Path` matches, yielding its symbol list reversed. */
  def unapply(n: Node): Option[List[Symbol]] = Path.unapply(n) match {
    case Some(symbols) => Some(symbols.reverse)
    case None => None
  }

  /** Render a symbol sequence as "Path a.b.c". */
  def toString(path: Seq[Symbol]): String = "Path " + path.mkString(".")
}
/** Extractor for a ProductNode whose children are all non-empty paths (as seen
  * through FwdPath) that start with the same symbol.
  *
  * On success it yields that leading symbol together with the remainder of each
  * child's path, in child order. It yields None when the node has no children,
  * when any child is not a non-empty FwdPath, or when a child starts with a
  * different symbol than the children before it. */
object ProductOfCommonPaths {
  def unapply(n: ProductNode): Option[(Symbol, Vector[List[Symbol]])] =
    n.nodeChildren.toList match {
      // The first child determines the symbol that all later children must share.
      // Seeding the fold from it avoids a null sentinel for "nothing seen yet".
      case FwdPath(sym0 :: rest0) :: others =>
        others.foldLeft(Option((sym0, Vector(rest0)))) {
          case (Some((shared, rests)), FwdPath(sym :: rest)) if sym == shared =>
            Some((sym, rests :+ rest))
          // A previous failure, a non-path child or a differing leading symbol:
          // the result stays None for the remainder of the fold.
          case _ => None
        }
      // No children at all, or the first child is not a non-empty path.
      case _ => None
    }
}
val newBy = by.replace { case Ref(f) if f == fromGen => Ref(gen) }
val newSel = sel.replace {
case Bind(s1, Select(Ref(gen2), ElementSymbol(2)), Pure(ProductNode(Seq(Select(Ref(s2), field)))))
if (s2 == s1) && (gen2 == gen) => Select(Ref(gen), field)
case Apply(fs, Seq(Bind(s1, Select(Ref(gen2), ElementSymbol(2)), Pure(ProductOfCommonPaths(s2, rests)))))
if (s2 == s1) && (gen2 == gen) =>
Apply(if(fs == Library.CountAll) Library.Count else fs, Seq(FwdPath(gen :: rests.head)))
case Library.CountAll(Select(Ref(gen2), ElementSymbol(2))) if gen2 == gen =>
Library.Count(ConstColumn(1))
case FwdPath(gen2 :: ElementSymbol(idx) :: rest) if gen2 == gen && (idx == 1 || idx == 2) =>

0 comments on commit 46bd568

Please sign in to comment.