
Add `Query.distinct` and `Query.distinctOn` operators

- `distinctOn` is encoded in a new `Distinct` node type in the AST.
  `q.distinct` is equivalent to `q.distinctOn(identity)`.
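
  A usage sketch of the new operators (the table and column names below
  are made up purely for illustration; the API addition itself is in
  slick/lifted/Query.scala, see the file list below):

    import slick.driver.H2Driver.api._

    class Coffees(tag: Tag) extends Table[(String, Double)](tag, "COFFEES") {
      def name  = column[String]("NAME")
      def price = column[Double]("PRICE")
      def * = (name, price)
    }
    val coffees = TableQuery[Coffees]

    // DISTINCT over the query's projection:
    val q1 = coffees.map(_.name).distinct
    // distinctness defined by an explicit “on” expression:
    val q2 = coffees.distinctOn(_.name)
    // coffees.distinct behaves like coffees.distinctOn(identity), i.e. it
    // is distinct over the table's * projection.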

- `Comprehension` gets a new field for an optional “distinct on” Node.
  A plain “distinct” is encoded as “distinct on ()” (an empty
  ProductNode).
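
  A hypothetical helper, only to illustrate the encoding (it uses the new
  `distinct` field of `Comprehension` shown in the diff below):

    import slick.ast.{Comprehension, ProductNode}
    import slick.util.ConstArray

    // A plain “distinct” is a “distinct on ()”: the “on” expression is an
    // empty ProductNode rather than a missing one.
    def markPlainDistinct(c: Comprehension): Comprehension =
      c.copy(distinct = Some(ProductNode(ConstArray.empty)))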

- In `forceOuterBinds` all `Distinct` nodes are wrapped (like `Join`
  and `Pure`).
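
  Roughly (a sketch of the shape, not the actual `ForceOuterBinds` code):

    import slick.ast.{AnonSymbol, Bind, Distinct, Pure, Ref}

    // Put a Distinct node under an aliasing outer Bind, the same treatment
    // that Join and Pure nodes receive.
    def wrapped(d: Distinct): Bind = {
      val s = new AnonSymbol
      Bind(s, d, Pure(Ref(s)))
    }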

- In `expandTables` table types are expanded into their star projections
  in “distinct on” clauses, so that distinctness is defined the same way
  as it would be if the operation were performed in Scala on the
  non-distinct results.
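
  A collection-level illustration using the test table `A` defined below,
  whose `*` projection is just the column `a`:

    // as.distinct is “distinct on” A's star projection, i.e. only on a, so
    // it has to match plain Scala .distinct on the fetched * values:
    val fetchedStars = Seq("a", "a", "c")   // * projections of the three test rows
    assert(fetchedStars.distinct == Seq("a", "c"))   // matches q1b in the test below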

- A new phase `rewriteDistinct`, which runs after `pruneProjections`,
  removes “distinct on” and replaces it with a simple “distinct” (adding
  an extra subquery where necessary), or falls back to a “group by”
  aggregation when “distinct” is not possible.
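
  A collection-level sketch of the “group by” fallback (illustrative only;
  the data matches the test below, where q5b may return Seq(1, 3) or
  Seq(2, 3)):

    // For as.distinct.map(_.id) distinctness is on a (the * projection) but
    // the selected value is id, so a plain DISTINCT on id would be wrong.
    // Instead the rows are grouped by a and one id is picked per group:
    val rows = Vector((1, "a"), (2, "a"), (3, "c"))   // (id, a)
    val oneIdPerA = rows.groupBy { case (_, a) => a }.values.map(_.head._1)
    // In SQL the representative row per group is arbitrary, which is why the
    // test accepts either Seq(1, 3) or Seq(2, 3).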

- In `reorderOperations` only distinctness-preserving aliasing Binds can
  be pushed down across `Subquery.AboveDistinct` boundaries. Mappings are
  recognized as distinctness-preserving if they are purely aliasing and
  do not drop any fields from the original type.
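
  A collection-level intuition for why only field-preserving aliasings are
  safe to push below the distinct boundary:

    val rows = Seq((1, "a"), (2, "a"))
    // Purely aliasing and keeping every field: mapping before or after the
    // distinct yields the same rows, so the Bind may be pushed down.
    assert(rows.distinct.map { case (i, s) => (s, i) } ==
           rows.map { case (i, s) => (s, i) }.distinct)
    // Dropping a field changes which rows are duplicates (2 rows vs 1), so
    // such a mapping must stay above the Subquery.AboveDistinct boundary.
    assert(rows.distinct.map(_._2) != rows.map(_._2).distinct)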

- In `mergeToComprehensions` the new `Distinct` nodes are merged. This
  complicates things somewhat because the rules for when something can
  be merged are no longer static. Distinct is merged together with
  SortBy, but both `mergeSortBy` and `mergeCommon` enforce that a
  DISTINCT is not created on top of an existing DISTINCT or HAVING, and
  that WHERE and HAVING are not created on top of an existing DISTINCT.
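
  For example, q1a in the test below compiles to a single Comprehension on
  both H2 and PostgreSQL, i.e. the Distinct is merged with the projection
  it sits on rather than forcing a subquery:

    val q1a = as.map(_.a).distinct   // assertNesting(q1a, 1) in testDistinct below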

- `JdbcStatementBuilderComponent` produces the standard SQL syntax for
  “distinct” and the PostgreSQL-specific syntax for “distinct on”.
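
  Illustrative SQL shapes (not verbatim generator output; aliasing and
  quoting differ):

    // standard “distinct” syntax, e.g. for as.map(_.a).distinct:
    //   SELECT DISTINCT a FROM "A_DISTINCT"
    // PostgreSQL-only “distinct on” syntax, e.g. for as.distinct.map(_.id):
    //   SELECT DISTINCT ON (a) id FROM "A_DISTINCT"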

- `PostgresDriver` omits the `rewriteDistinct` phase and instead uses a
  local optimization in the code generator to emit “distinct on” clauses
  where possible, without the risk of creating extra subqueries.
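
  The difference shows up in the nesting assertions of the test below:

    // q4 = as.map(a => (a.a, a.b)).distinct.map(_._1)
    //   H2 (default pipeline with rewriteDistinct): assertNesting(q4, 2),
    //     because the DISTINCT over (a, b) ends up in a subquery below the
    //     outer projection of a.
    //   PostgreSQL (rewriteDistinct omitted): assertNesting(q4, 1), because
    //     the code generator emits something like
    //     SELECT DISTINCT ON (a, b) a ... directly.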

- Add support for `Distinct` to `QueryInterpreter`.
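
  A collection-level sketch of the evaluation (not the actual
  `QueryInterpreter` code): keep the first row seen for each value of the
  “on” expression; a plain distinct is the case where the “on” expression
  is the row itself.

    import scala.collection.mutable

    def distinctOn[R, K](rows: Vector[R])(on: R => K): Vector[R] = {
      val seen = mutable.HashSet.empty[K]
      rows.filter(r => seen.add(on(r)))   // add returns false for duplicates
    }

    // distinctOn(Vector("a", "a", "c"))(identity) == Vector("a", "c")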

Fixes #96. Tests in AggregateTest.testDistinct.
szeiger committed Sep 3, 2015
1 parent e933ed3 commit 9d8bbee624321a1f51dafb216cd9b490489f2cb2
Showing with 361 additions and 100 deletions.
  1. +1 −0 common-test-resources/logback.xml
  2. +1 −0 slick-testkit/src/doctest/resources/logback.xml
  3. +54 −0 slick-testkit/src/main/scala/com/typesafe/slick/testkit/tests/AggregateTest.scala
  4. +0 −14 slick-testkit/src/main/scala/com/typesafe/slick/testkit/tests/NewQuerySemanticsTest.scala
  5. +17 −3 slick-testkit/src/main/scala/com/typesafe/slick/testkit/util/Testkit.scala
  6. +12 −5 slick/src/main/scala/slick/ast/Comprehension.scala
  7. +31 −15 slick/src/main/scala/slick/ast/Node.scala
  8. +7 −2 slick/src/main/scala/slick/ast/Type.scala
  9. +28 −17 slick/src/main/scala/slick/compiler/ExpandTables.scala
  10. +2 −2 slick/src/main/scala/slick/compiler/ForceOuterBinds.scala
  11. +45 −25 slick/src/main/scala/slick/compiler/MergeToComprehensions.scala
  12. +2 −0 slick/src/main/scala/slick/compiler/QueryCompiler.scala
  13. +15 −2 slick/src/main/scala/slick/compiler/ReorderOperations.scala
  14. +1 −1 slick/src/main/scala/slick/compiler/RewriteBooleans.scala
  15. +58 −0 slick/src/main/scala/slick/compiler/RewriteDistinct.scala
  16. +2 −2 slick/src/main/scala/slick/compiler/SpecializeParameters.scala
  17. +1 −1 slick/src/main/scala/slick/compiler/VerifySymbols.scala
  18. +10 −4 slick/src/main/scala/slick/driver/JdbcStatementBuilderComponent.scala
  19. +1 −1 slick/src/main/scala/slick/driver/MySQLDriver.scala
  20. +18 −2 slick/src/main/scala/slick/driver/PostgresDriver.scala
  21. +14 −0 slick/src/main/scala/slick/lifted/Query.scala
  22. +15 −0 slick/src/main/scala/slick/memory/QueryInterpreter.scala
  23. +21 −0 slick/src/main/scala/slick/util/ConstArray.scala
  24. +5 −4 slick/src/main/scala/slick/util/SQLBuilder.scala
@@ -34,6 +34,7 @@
<logger name="slick.compiler.OptimizeScalar" level="${log.qcomp.optimizeScalar:-inherited}" />
<logger name="slick.compiler.FixRowNumberOrdering" level="${log.qcomp.fixRowNumberOrdering:-inherited}" />
<logger name="slick.compiler.PruneProjections" level="${log.qcomp.pruneProjections:-inherited}" />
<logger name="slick.compiler.RewriteDistinct" level="${log.qcomp.rewriteDistinct:-inherited}" />
<logger name="slick.compiler.RewriteBooleans" level="${log.qcomp.rewriteBooleans:-inherited}" />
<logger name="slick.compiler.SpecializeParameters" level="${log.qcomp.specializeParameters:-inherited}" />
<logger name="slick.compiler.CodeGen" level="${log.qcomp.codeGen:-inherited}" />
@@ -34,6 +34,7 @@
<logger name="slick.compiler.OptimizeScalar" level="${log.qcomp.optimizeScalar:-inherited}" />
<logger name="slick.compiler.FixRowNumberOrdering" level="${log.qcomp.fixRowNumberOrdering:-inherited}" />
<logger name="slick.compiler.PruneProjections" level="${log.qcomp.pruneProjections:-inherited}" />
<logger name="slick.compiler.RewriteDistinct" level="${log.qcomp.rewriteDistinct:-inherited}" />
<logger name="slick.compiler.RewriteBooleans" level="${log.qcomp.rewriteBooleans:-inherited}" />
<logger name="slick.compiler.SpecializeParameters" level="${log.qcomp.specializeParameters:-inherited}" />
<logger name="slick.compiler.CodeGen" level="${log.qcomp.codeGen:-inherited}" />
@@ -1,6 +1,7 @@
package com.typesafe.slick.testkit.tests
import com.typesafe.slick.testkit.util.{AsyncTest, RelationalTestDB}
import slick.driver.{H2Driver, PostgresDriver}
class AggregateTest extends AsyncTest[RelationalTestDB] {
import tdb.profile.api._
@@ -280,4 +281,57 @@ class AggregateTest extends AsyncTest[RelationalTestDB] {
_ <- q4.result.map(_ shouldBe Nil)
} yield ()
}
def testDistinct = {
class A(tag: Tag) extends Table[String](tag, "A_DISTINCT") {
def id = column[Int]("id", O.PrimaryKey)
def a = column[String]("a")
def b = column[String]("b")
def * = a
override def create_* = collectFieldSymbols((id, a, b).shaped.toNode)
}
val as = TableQuery[A]
val data = Set((1, "a", "a"), (2, "a", "b"), (3, "c", "b"))
val q1a = as.map(_.a).distinct
val q1b = as.distinct.map(_.a)
val q2 = as.distinct.map(a => (a.a, 5))
val q3a = as.distinct.map(_.id).filter(_ === 1) unionAll as.distinct.map(_.id).filter(_ === 2)
val q4 = as.map(a => (a.a, a.b)).distinct.map(_._1)
val q5a = as.groupBy(_.a).map(_._2.map(_.id).min.get)
val q5b = as.distinct.map(_.id)
val q5c = as.distinct.map(a => (a.id, a.a))
if(tdb.driver == H2Driver) {
assertNesting(q1a, 1)
assertNesting(q1b, 1)
assertNesting(q3a, 2)
assertNesting(q4, 2)
assertNesting(q5a, 1)
assertNesting(q5b, 1)
assertNesting(q5c, 1)
} else if(tdb.driver == PostgresDriver) {
assertNesting(q1a, 1)
assertNesting(q1b, 1)
assertNesting(q3a, 4)
assertNesting(q4, 1)
assertNesting(q5a, 1)
assertNesting(q5b, 1)
assertNesting(q5c, 1)
}
DBIO.seq(
as.schema.create,
as.map(a => (a.id, a.a, a.b)) ++= data,
mark("q1a", q1a.result).map(_.sortBy(identity) shouldBe Seq("a", "c")),
mark("q1b", q1b.result).map(_.sortBy(identity) shouldBe Seq("a", "c")),
mark("q2", q2.result).map(_.sortBy(identity) shouldBe Seq(("a", 5), ("c", 5))),
mark("q3a", q3a.result).map(_ should (r => r == Seq(1) || r == Seq(2))),
mark("q4", q4.result).map(_.sortBy(identity) shouldBe Seq("a", "a", "c")),
mark("q5a", q5a.result).map(_.sortBy(identity) shouldBe Seq(1, 3)),
mark("q5b", q5b.result).map(_.sortBy(identity) should (r => r == Seq(1, 3) || r == Seq(2, 3))),
mark("q5c", q5c.result).map(_.sortBy(identity) should (r => r == Seq((1, "a"), (3, "c")) || r == Seq((2, "a"), (3, "c"))))
)
}
}
@@ -1,9 +1,7 @@
package com.typesafe.slick.testkit.tests
import slick.SlickTreeException
import slick.driver.H2Driver
import scala.language.higherKinds
import com.typesafe.slick.testkit.util.{RelationalTestDB, AsyncTest}
class NewQuerySemanticsTest extends AsyncTest[RelationalTestDB] {
@@ -569,16 +567,4 @@ class NewQuerySemanticsTest extends AsyncTest[RelationalTestDB] {
_ <- mark("q19", q19.result).map(_.toSet shouldBe Set(Some((1,"a","a")), Some((2,"a","b")), Some((3,"c","b"))))
} yield ()
}
def assertNesting(q: Rep[_], exp: Int): Unit = {
import slick.compiler.QueryCompiler
import slick.ast._
import slick.ast.Util._
val qc = new QueryCompiler(tdb.driver.queryCompiler.phases.takeWhile(_.name != "codeGen"))
val cs = qc.run(q.toNode)
val found = cs.tree.collect { case c: Comprehension => c }.length
if(found != exp)
throw cs.symbolNamer.use(new SlickTreeException(s"Found $found Comprehension nodes, should be $exp",
cs.tree, mark = (_.isInstanceOf[Comprehension]), removeUnmarked = false))
}
}
@@ -1,7 +1,5 @@
package com.typesafe.slick.testkit.util
import org.slf4j.MDC
import scala.language.existentials
import scala.concurrent.{Promise, ExecutionContext, Await, Future, blocking}
@@ -14,8 +12,10 @@ import java.lang.reflect.Method
import java.util.concurrent.{LinkedBlockingQueue, ThreadPoolExecutor, ExecutionException, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger
import slick.SlickTreeException
import slick.dbio._
import slick.jdbc.JdbcBackend
import slick.lifted.Rep
import slick.util.DumpInfo
import slick.profile.{RelationalProfile, SqlProfile, Capability}
import slick.driver.JdbcProfile
@@ -25,6 +25,8 @@ import org.junit.runner.notification.RunNotifier
import org.junit.runners.model._
import org.junit.Assert
import org.slf4j.MDC
import org.reactivestreams.{Subscription, Subscriber, Publisher}
/** JUnit runner for the Slick driver test kit. */
@@ -159,6 +161,18 @@ sealed abstract class GenericTest[TDB >: Null <: TestDB](implicit TdbClass: Clas
final def mark[R, S <: NoStream, E <: Effect](id: String, f: => DBIOAction[R, S, E]): DBIOAction[R, S, E] =
mark[DBIOAction[R, S, E]](id, f.named(id))
def assertNesting(q: Rep[_], exp: Int): Unit = {
import slick.compiler.QueryCompiler
import slick.ast._
import slick.ast.Util._
val qc = new QueryCompiler(tdb.driver.queryCompiler.phases.takeWhile(_.name != "codeGen"))
val cs = qc.run(q.toNode)
val found = cs.tree.collect { case c: Comprehension => c }.length
if(found != exp)
throw cs.symbolNamer.use(new SlickTreeException(s"Found $found Comprehension nodes, should be $exp",
cs.tree, mark = (_.isInstanceOf[Comprehension]), removeUnmarked = false))
}
def rcap = RelationalProfile.capabilities
def scap = SqlProfile.capabilities
def jcap = JdbcProfile.capabilities
@@ -325,7 +339,7 @@ abstract class AsyncTest[TDB >: Null <: TestDB](implicit TdbClass: ClassTag[TDB]
def shouldNotBe(o: Any): Unit = fixStack(Assert.assertNotSame(o, v))
def should(f: T => Boolean): Unit = fixStack(Assert.assertTrue(f(v)))
def should(f: T => Boolean): Unit = fixStack(Assert.assertTrue("'should' assertion failed for value: "+v, f(v)))
def shouldFail(f: T => Unit): Unit = {
var ok = false
@@ -8,15 +8,17 @@ import slick.util.ConstArray
final case class Comprehension(sym: TermSymbol, from: Node, select: Node, where: Option[Node] = None,
groupBy: Option[Node] = None, orderBy: ConstArray[(Node, Ordering)] = ConstArray.empty,
having: Option[Node] = None,
distinct: Option[Node] = None,
fetch: Option[Node] = None, offset: Option[Node] = None) extends DefNode {
type Self = Comprehension
lazy val children = (ConstArray.newBuilder() + from + select ++ where ++ groupBy ++ orderBy.map(_._1) ++ having ++ fetch ++ offset).result
lazy val children = (ConstArray.newBuilder() + from + select ++ where ++ groupBy ++ orderBy.map(_._1) ++ having ++ distinct ++ fetch ++ offset).result
override def childNames =
Seq("from "+sym, "select") ++
where.map(_ => "where") ++
groupBy.map(_ => "groupBy") ++
orderBy.map("orderBy " + _._2).toSeq ++
having.map(_ => "having") ++
distinct.map(_ => "distinct") ++
fetch.map(_ => "fetch") ++
offset.map(_ => "offset")
protected[this] def rebuild(ch: ConstArray[Node]) = {
@@ -30,7 +32,9 @@ final case class Comprehension(sym: TermSymbol, from: Node, select: Node, where:
val newOrderBy = ch.slice(orderByOffset, orderByOffset + orderBy.length)
val havingOffset = orderByOffset + newOrderBy.length
val newHaving = ch.slice(havingOffset, havingOffset + having.productArity)
val fetchOffset = havingOffset + newHaving.length
val distinctOffset = havingOffset + newHaving.length
val newDistinct = ch.slice(distinctOffset, distinctOffset + distinct.productArity)
val fetchOffset = distinctOffset + newDistinct.length
val newFetch = ch.slice(fetchOffset, fetchOffset + fetch.productArity)
val offsetOffset = fetchOffset + newFetch.length
val newOffset = ch.slice(offsetOffset, offsetOffset + offset.productArity)
@@ -41,28 +45,30 @@ final case class Comprehension(sym: TermSymbol, from: Node, select: Node, where:
groupBy = newGroupBy.headOption,
orderBy = orderBy.zip(newOrderBy).map { case ((_, o), n) => (n, o) },
having = newHaving.headOption,
distinct = newDistinct.headOption,
fetch = newFetch.headOption,
offset = newOffset.headOption
)
}
def generators = ConstArray((sym, from))
override def getDumpInfo = super.getDumpInfo.copy(mainInfo = "")
protected[this] def rebuildWithSymbols(gen: ConstArray[TermSymbol]) = copy(sym = gen.head)
def withInferredType(scope: Type.Scope, typeChildren: Boolean): Self = {
// Assign type to "from" Node and compute the resulting scope
val f2 = from.infer(scope, typeChildren)
val genScope = scope + (sym -> f2.nodeType.asCollectionType.elementType)
// Assign types to "select", "where", "groupBy", "orderBy", "having", "fetch" and "offset" Nodes
// Assign types to "select", "where", "groupBy", "orderBy", "having", "distinct", "fetch" and "offset" Nodes
val s2 = select.infer(genScope, typeChildren)
val w2 = mapOrNone(where)(_.infer(genScope, typeChildren))
val g2 = mapOrNone(groupBy)(_.infer(genScope, typeChildren))
val o = orderBy.map(_._1)
val o2 = o.endoMap(_.infer(genScope, typeChildren))
val h2 = mapOrNone(having)(_.infer(genScope, typeChildren))
val distinct2 = mapOrNone(distinct)(_.infer(genScope, typeChildren))
val fetch2 = mapOrNone(fetch)(_.infer(genScope, typeChildren))
val offset2 = mapOrNone(offset)(_.infer(genScope, typeChildren))
// Check if the nodes changed
val same = (f2 eq from) && (s2 eq select) && w2.isEmpty && g2.isEmpty && (o2 eq o) && h2.isEmpty && fetch2.isEmpty && offset2.isEmpty
val same = (f2 eq from) && (s2 eq select) && w2.isEmpty && g2.isEmpty && (o2 eq o) && h2.isEmpty &&
distinct2.isEmpty && fetch2.isEmpty && offset2.isEmpty
val newType =
if(!hasType) CollectionType(f2.nodeType.asCollectionType.cons, s2.nodeType.asCollectionType.elementType)
else nodeType
@@ -74,6 +80,7 @@ final case class Comprehension(sym: TermSymbol, from: Node, select: Node, where:
groupBy = g2.orElse(groupBy),
orderBy = if(o2 eq o) orderBy else orderBy.zip(o2).map { case ((_, o), n) => (n, o) },
having = h2.orElse(having),
distinct = distinct2.orElse(distinct),
fetch = fetch2.orElse(fetch),
offset = offset2.orElse(offset)
) :@ newType
@@ -266,8 +266,8 @@ final case class Subquery(child: Node, condition: Subquery.Condition) extends Un
object Subquery {
sealed trait Condition
/** Always create a subquery */
case object Always extends Condition
/** Always create a subquery but allow purely aliasing projections to be pushed down */
case object Default extends Condition
/** A Subquery boundary below the mapping operation that adds a ROWNUM */
case object BelowRownum extends Condition
/** A Subquery boundary above the mapping operation that adds a ROWNUM */
@@ -276,31 +276,39 @@ object Subquery {
case object BelowRowNumber extends Condition
/** A Subquery boundary above the mapping operation that adds a ROW_NUMBER */
case object AboveRowNumber extends Condition
/** A Subquery boundary above a DISTINCT without explicit column specification */
case object AboveDistinct extends Condition
}
/** Common superclass for expressions of type (CollectionType(c, t), _) => CollectionType(c, t). */
abstract class FilteredQuery extends Node {
type Self >: this.type <: FilteredQuery
protected[this] def generator: TermSymbol
def from: Node
def generators = ConstArray((generator, from))
override def getDumpInfo = super.getDumpInfo.copy(mainInfo = this match {
case p: Product => p.productIterator.filterNot(n => n.isInstanceOf[Node] || n.isInstanceOf[Symbol]).mkString(", ")
case _ => ""
})
}
/** A FilteredQuery without a Symbol. */
abstract class SimpleFilteredQuery extends FilteredQuery with SimplyTypedNode {
type Self >: this.type <: SimpleFilteredQuery
def buildType = from.nodeType
}
/** A FilteredQuery with a Symbol. */
abstract class ComplexFilteredQuery extends FilteredQuery with DefNode {
type Self >: this.type <: ComplexFilteredQuery
protected[this] def generator: TermSymbol
def generators = ConstArray((generator, from))
def withInferredType(scope: Type.Scope, typeChildren: Boolean): Self = {
val from2 = from.infer(scope, typeChildren)
val genScope = scope + (generator -> from2.nodeType.asCollectionType.elementType)
val this2 = mapChildren { ch =>
val this2 = mapChildren { ch =>
if(ch eq from) from2 else ch.infer(genScope, typeChildren)
}
(this2 :@ (if(!hasType) this2.from.nodeType else nodeType)).asInstanceOf[Self]
}
}
/** A .filter call of type (CollectionType(c, t), Boolean) => CollectionType(c, t). */
final case class Filter(generator: TermSymbol, from: Node, where: Node) extends FilteredQuery with BinaryNode with DefNode {
final case class Filter(generator: TermSymbol, from: Node, where: Node) extends ComplexFilteredQuery with BinaryNode {
type Self = Filter
def left = from
def right = where
@@ -317,7 +325,7 @@ object Filter {
}
/** A .sortBy call of type (CollectionType(c, t), _) => CollectionType(c, t). */
final case class SortBy(generator: TermSymbol, from: Node, by: ConstArray[(Node, Ordering)]) extends FilteredQuery with DefNode {
final case class SortBy(generator: TermSymbol, from: Node, by: ConstArray[(Node, Ordering)]) extends ComplexFilteredQuery {
type Self = SortBy
lazy val children = from +: by.map(_._1)
protected[this] def rebuild(ch: ConstArray[Node]) =
@@ -370,25 +378,33 @@ final case class GroupBy(fromGen: TermSymbol, from: Node, by: Node, identity: Ty
}
/** A .take call. */
final case class Take(from: Node, count: Node) extends FilteredQuery with BinaryNode {
final case class Take(from: Node, count: Node) extends SimpleFilteredQuery with BinaryNode {
type Self = Take
def left = from
def right = count
protected[this] val generator = new AnonSymbol
override def childNames = Seq("from", "count")
protected[this] def rebuild(left: Node, right: Node) = copy(from = left, count = right)
}
/** A .drop call. */
final case class Drop(from: Node, count: Node) extends FilteredQuery with BinaryNode {
final case class Drop(from: Node, count: Node) extends SimpleFilteredQuery with BinaryNode {
type Self = Drop
def left = from
def right = count
protected[this] val generator = new AnonSymbol
override def childNames = Seq("from", "count")
protected[this] def rebuild(left: Node, right: Node) = copy(from = left, count = right)
}
/** A .distinct call of type (CollectionType(c, t), _) => CollectionType(c, t). */
final case class Distinct(generator: TermSymbol, from: Node, on: Node) extends ComplexFilteredQuery with BinaryNode {
type Self = Distinct
def left = from
def right = on
override def childNames = Seq("from", "on")
protected[this] def rebuild(left: Node, right: Node) = copy(from = left, on = right)
protected[this] def rebuildWithSymbols(gen: ConstArray[TermSymbol]) = copy(generator = gen(0))
}
/** A join expression. For joins without option extension, the type rule is
* (CollectionType(c, t), CollectionType(_, u)) => CollecionType(c, (t, u)).
* Option-extended left outer joins are typed as
@@ -305,13 +305,18 @@ class TypeUtil(val tpe: Type) extends AnyVal {
b.result
}
def containsSymbol(tss: scala.collection.Set[TypeSymbol]): Boolean = {
def existsType(f: Type => Boolean): Boolean =
if(f(tpe)) true else tpe match {
case t: AtomicType => false
case t => t.children.exists(_.existsType(f))
}
def containsSymbol(tss: scala.collection.Set[TypeSymbol]): Boolean =
if(tss.isEmpty) false else tpe match {
case NominalType(ts, exp) => tss.contains(ts) || exp.containsSymbol(tss)
case t: AtomicType => false
case t => t.children.exists(_.containsSymbol(tss))
}
}
}
object TypeUtil {