Permalink
Browse files

Allow rewriting of Distinct within Aggregate (in addition to Bind)

Test in AggregateTest.testDistinct. Fixes #1325.
  • Loading branch information...
szeiger committed Dec 11, 2015
1 parent 73a0166 commit 86acec55b3768a93e424913bcc533feb55452a83
@@ -328,6 +328,7 @@ class AggregateTest extends AsyncTest[RelationalTestDB] {
val q5a = as.groupBy(_.a).map(_._2.map(_.id).min.get)
val q5b = as.distinct.map(_.id)
val q5c = as.distinct.map(a => (a.id, a.a))
val q6 = as.distinct.length
if(tdb.driver == H2Driver) {
assertNesting(q1a, 1)
@@ -357,7 +358,8 @@ class AggregateTest extends AsyncTest[RelationalTestDB] {
mark("q4", q4.result).map(_.sortBy(identity) shouldBe Seq("a", "a", "c")),
mark("q5a", q5a.result).map(_.sortBy(identity) shouldBe Seq(1, 3)),
mark("q5b", q5b.result).map(_.sortBy(identity) should (r => r == Seq(1, 3) || r == Seq(2, 3))),
mark("q5c", q5c.result).map(_.sortBy(identity) should (r => r == Seq((1, "a"), (3, "c")) || r == Seq((2, "a"), (3, "c"))))
mark("q5c", q5c.result).map(_.sortBy(identity) should (r => r == Seq((1, "a"), (3, "c")) || r == Seq((2, "a"), (3, "c")))),
mark("q6", q6.result).map(_ shouldBe 2)
)
}
}
@@ -12,47 +12,60 @@ class RewriteDistinct extends Phase {
def apply(state: CompilerState) = if(state.get(Phase.assignUniqueSymbols).map(_.distinct).getOrElse(true)) state.map(_.replace({
case n @ Bind(s1, Distinct(s2, from1, on1), Pure(sel1, ts1)) =>
logger.debug("Rewriting Distinct:", Ellipsis(n, List(0, 0)))
val refFields = sel1.collect[TermSymbol] {
case Select(Ref(s), f) if s == s1 => f
}.toSet
logger.debug("Referenced fields: " + refFields.mkString(", "))
val onFlat = ProductNode(ConstArray(on1)).flatten
val onNodes = onFlat.children.toSet
val onFieldPos = onNodes.iterator.zipWithIndex.collect[(TermSymbol, Int)] {
case (Select(Ref(s), f), idx) if s == s2 => (f, idx)
case n @ Bind(s1, dist1: Distinct, Pure(sel1, ts1)) =>
logger.debug("Rewriting Distinct in Bind:", Ellipsis(n, List(0, 0)))
val (inner, sel2) = rewrite(s1, dist1, sel1)
Bind(s1, inner, Pure(sel2, ts1)).infer()
case n @ Aggregate(s1, dist1: Distinct, sel1) =>
logger.debug("Rewriting Distinct in Aggregate:", Ellipsis(n, List(0, 0)))
val (inner, sel2) = rewrite(s1, dist1, sel1)
Aggregate(s1, inner, sel2).infer()
}, keepType = true, bottomUp = true)) else {
logger.debug("No DISTINCT used as determined by assignUniqueSymbols - skipping phase")
state
}
def rewrite(s1: TermSymbol, dist1: Distinct, sel1: Node): (Node, Node) = {
val refFields = sel1.collect[TermSymbol] {
case Select(Ref(s), f) if s == s1 => f
}.toSet
logger.debug("Referenced fields: " + refFields.mkString(", "))
val onFlat = ProductNode(ConstArray(dist1.on)).flatten
val onNodes = onFlat.children.toSet
val onFieldPos = onNodes.iterator.zipWithIndex.collect[(TermSymbol, Int)] {
case (Select(Ref(s), f), idx) if s == dist1.generator => (f, idx)
}.toMap
logger.debug("Fields used directly in 'on' clause: " + onFieldPos.keySet.mkString(", "))
if((refFields -- onFieldPos.keys).isEmpty) {
// Only distinct fields referenced -> Create subquery and remove 'on' clause
val onDefs = ConstArray.from(onNodes).map((new AnonSymbol, _))
val onLookup = onDefs.iterator.collect[(TermSymbol, AnonSymbol)] {
case (a, Select(Ref(s), f)) if s == dist1.generator => (f, a)
}.toMap
logger.debug("Fields used directly in 'on' clause: " + onFieldPos.keySet.mkString(", "))
if((refFields -- onFieldPos.keys).isEmpty) {
// Only distinct fields referenced -> Create subquery and remove 'on' clause
val onDefs = ConstArray.from(onNodes).map((new AnonSymbol, _))
val onLookup = onDefs.iterator.collect[(TermSymbol, AnonSymbol)] {
case (a, Select(Ref(s), f)) if s == s2 => (f, a)
}.toMap
val inner = Bind(s2, Distinct(new AnonSymbol, from1, ProductNode(ConstArray.empty)), Pure(StructNode(onDefs)))
val sel2 = sel1.replace {
case Select(Ref(s), f) if s == s1 => Select(Ref(s), onLookup(f))
}
val ret = Bind(s1, Subquery(inner, Subquery.AboveDistinct), Pure(sel2, ts1)).infer()
logger.debug("Removed 'on' clause from Distinct:", Ellipsis(ret, List(0, 0, 0, 0)))
ret
} else {
val sel2 = sel1.replace {
case Select(Ref(s), f) :@ tpe if s == s1 =>
onFieldPos.get(f) match {
case Some(idx) =>
Select(Select(Ref(s), ElementSymbol(1)), ElementSymbol(idx+1))
case None =>
val as = new AnonSymbol
Aggregate(as, Select(Ref(s), ElementSymbol(2)),
Library.Min.typed(tpe, Select(Ref(as), f)))
}
}
val ret = Bind(s1, GroupBy(s2, from1, onFlat), Pure(sel2, ts1)).infer()
logger.debug("Transformed Distinct to GroupBy:", Ellipsis(ret, List(0, 0)))
ret
val inner = Bind(dist1.generator, Distinct(new AnonSymbol, dist1.from, ProductNode(ConstArray.empty)), Pure(StructNode(onDefs)))
val sel2 = sel1.replace {
case Select(Ref(s), f) if s == s1 => Select(Ref(s), onLookup(f))
}
}, keepType = true, bottomUp = true)) else state
val ret = Subquery(inner, Subquery.AboveDistinct)
logger.debug("Removed 'on' clause from Distinct:", Ellipsis(ret, List(0, 0, 0)))
(ret, sel2)
} else {
val sel2 = sel1.replace {
case Select(Ref(s), f) :@ tpe if s == s1 =>
onFieldPos.get(f) match {
case Some(idx) =>
Select(Select(Ref(s), ElementSymbol(1)), ElementSymbol(idx+1))
case None =>
val as = new AnonSymbol
Aggregate(as, Select(Ref(s), ElementSymbol(2)),
Library.Min.typed(tpe, Select(Ref(as), f)))
}
}
val ret = GroupBy(dist1.generator, dist1.from, onFlat)
logger.debug("Transformed Distinct to GroupBy:", Ellipsis(ret, List(0)))
(ret, sel2)
}
}
}

0 comments on commit 86acec5

Please sign in to comment.