From 1c1bf30b5265cc33cfb85156d34d55f53457dd9b Mon Sep 17 00:00:00 2001 From: Stefan Zeiger Date: Tue, 8 Oct 2013 22:26:58 +0200 Subject: [PATCH] Fuse simple mappings before resolving GroupBy. This enables the use of multiple mapping steps for extracting aggregated values from groups. Test in AggregateTest.testMultiMapAggregates. Fixes issues #186, #187, #189. --- .../slick/testkit/tests/AggregateTest.scala | 59 +++++++++++++++++++ src/main/scala/scala/slick/ast/Util.scala | 2 +- .../scala/slick/compiler/Relational.scala | 33 ++++++++--- 3 files changed, 84 insertions(+), 10 deletions(-) diff --git a/slick-testkit/src/main/scala/com/typesafe/slick/testkit/tests/AggregateTest.scala b/slick-testkit/src/main/scala/com/typesafe/slick/testkit/tests/AggregateTest.scala index 2dd3adcc1f..1120bd08bb 100644 --- a/slick-testkit/src/main/scala/com/typesafe/slick/testkit/tests/AggregateTest.scala +++ b/slick-testkit/src/main/scala/com/typesafe/slick/testkit/tests/AggregateTest.scala @@ -155,4 +155,63 @@ class AggregateTest extends TestkitTest[RelationalTestDB] { } assertEquals(Set(("baz","quux",Some(4)), ("foo","quux",Some(3)), ("foo","bar",Some(3))), q1.run.toSet) } + + def testMultiMapAggregates { + class B(tag: Tag) extends Table[(Long, String, String)](tag, "b_multimap") { + def id = column[Long]("id", O.PrimaryKey) + def b = column[String]("b") + def d = column[String]("d") + + def * = (id, b, d) + } + val bs = TableQuery[B] + class A(tag: Tag) extends Table[(Long, String, Long, Long)](tag, "a_multimap") { + def id = column[Long]("id", O.PrimaryKey) + def a = column[String]("a") + def c = column[Long]("c") + def fkId = column[Long]("fkId") + def * = (id, a, c, fkId) + } + val as = TableQuery[A] + (as.ddl ++ bs.ddl).create + + val q1 = as.groupBy(_.id).map(_._2.map(x => x).map(x => x.a).min) + assert(q1.run.toList.isEmpty) + + val q2 = + (as leftJoin bs on (_.id === _.id)).map { case (c, s) => + val name = s.b + (c, s, name) + }.groupBy { prop => + val c = prop._1 + val s = prop._2 + val name = prop._3 + s.id + }.map { prop => + val supId = prop._1 + val c = prop._2.map(x => x._1) + val s = prop._2.map(x => x._2) + val name = prop._2.map(x => x._3) + (name.min, s.map(_.b).min, supId, c.length) + } + assert(q2.run.isEmpty) + + val q4 = as.flatMap { t1 => + bs.withFilter { t2 => + t1.fkId === t2.id && t2.d === "" + }.map(t2 => (t1, t2)) + }.groupBy { prop => + val t1 = prop._1 + val t2 = prop._2 + (t1.a, t2.b) + }.map { prop => + val a = prop._1._1 + val b = prop._1._2 + val t1 = prop._2.map(_._1) + val t2 = prop._2.map(_._2) + val c3 = t1.map(_.c).max + scala.Tuple3(a, b, c3) + } + assert(q4.run.isEmpty) + } } diff --git a/src/main/scala/scala/slick/ast/Util.scala b/src/main/scala/scala/slick/ast/Util.scala index d7c34cb396..b4d3144d3e 100644 --- a/src/main/scala/scala/slick/ast/Util.scala +++ b/src/main/scala/scala/slick/ast/Util.scala @@ -89,7 +89,7 @@ object NodeOps { } def replace(tree: Node, f: PartialFunction[Node, Node], keepType: Boolean): Node = - f.applyOrElse(tree, ({ case n: Node => n.nodeMapChildren(_.replace(f), keepType) }): PartialFunction[Node, Node]) + f.applyOrElse(tree, ({ case n: Node => n.nodeMapChildren(_.replace(f, keepType), keepType) }): PartialFunction[Node, Node]) } /** Some less general but still useful methods for the code generators. */ diff --git a/src/main/scala/scala/slick/compiler/Relational.scala b/src/main/scala/scala/slick/compiler/Relational.scala index 6a3860e939..f13022a018 100644 --- a/src/main/scala/scala/slick/compiler/Relational.scala +++ b/src/main/scala/scala/slick/compiler/Relational.scala @@ -74,7 +74,30 @@ class ConvertToComprehensions extends Phase { case n => Seq((s, n)) } - def convert(n: Node): Node = (n.nodeMapChildren(convert, keepType = true) match { + def convert(n: Node): Node = convert1(n.nodeMapChildren(convert, keepType = true)) match { + case c1 @ Comprehension(from1, where1, None, orderBy1, + Some(c2 @ Comprehension(from2, where2, None, orderBy2, select, None, None)), + fetch, offset) => + c2.copy(from = from1 ++ from2, where = where1 ++ where2, + orderBy = orderBy2 ++ orderBy1, fetch = fetch, offset = offset + ).nodeTyped(c1.nodeType) + case n => n + } + + def convert1(n: Node): Node = n match { + // Fuse simple mappings. This enables the use of multiple mapping steps + // for extracting aggregated values from groups. We have to do it here + // because Comprehension fusion comes after the special rewriting that + // we have to do for GroupBy aggregation. + case Bind(ogen, Comprehension(Seq((igen, from)), Nil, None, Nil, Some(Pure(isel, _)), None, None), Pure(osel, oident)) => + logger.debug("Fusing simple mapping:", n) + val sel = osel.replace({ + case FwdPath(base :: rest) if base == ogen => + rest.foldLeft(isel)(_ select _) + }, keepType = true) + val res = Bind(igen, from, Pure(sel, oident).nodeTyped(n.nodeType)).nodeTyped(n.nodeType) + logger.debug("Fused to:", res) + convert1(res) // Table to Comprehension case t: TableNode => val gen = new AnonSymbol @@ -120,14 +143,6 @@ class ConvertToComprehensions extends Phase { else Comprehension(from = mkFrom(new AnonSymbol, from), fetch = take.map(_.toLong), offset = drop2.map(_.toLong)) c.nodeTyped(td.nodeType) case n => n - }) match { - case c1 @ Comprehension(from1, where1, None, orderBy1, - Some(c2 @ Comprehension(from2, where2, None, orderBy2, select, None, None)), - fetch, offset) => - c2.copy(from = from1 ++ from2, where = where1 ++ where2, - orderBy = orderBy2 ++ orderBy1, fetch = fetch, offset = offset - ).nodeTyped(c1.nodeType) - case n => n } /** An extractor for nested Take and Drop nodes */