Fix type bugs and enforce correct types in the query compiler:

- `Select` was still a special case for type-checking, keeping its
  node type when rebuilding. It now loses the type as it should, which
  is important when re-typing a tree after performing a change somewhere
  inside it. Only a single change in `assignUniqueSymbols` (where we
  have to preserve all field types because the table types are not yet
  known) is required to support this change.

- Add `NodeOps.replaceFold` for performing stateful bottom-up
  transformations of ASTs. It is used by `NodeOps.replaceInvalidate` for
  a transformation that collects TypeSymbols to invalidate and unassigns
  types from references containing these symbols.

- Remove the recently added `retype` flag from `NodeOps.replace`. Now
  that all nodes drop their types when a subtree is rebuilt, it is no
  longer needed. We can `infer()` the type after transforming (see the
  sketch below).

- Fix lots of small typing bugs in many compiler phases.

- Add a new `WellTyped` parameter to the `CompilerState` to indicate
  which part of the AST is supposed to be well-typed after running a
  phase.

- Add a new optional phase `verifyTypes` to check well-typedness of the
  AST. It can be run automatically after every regular phase by setting
  `slick.verifyTypes = true` in the application config.

- Hardcode config defaults in `GlobalConfig` for cases where
  `reference.conf` cannot be resolved (e.g. in macros).
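
For illustration, here is a minimal usage sketch of the transform-then-retype
workflow described above. The `renameFields` helper and the `renames` map are
hypothetical; `replace`, `Select` and `infer()` are the existing APIs touched
by this commit.

    import slick.ast._
    import slick.ast.Util._

    // Hypothetical rewrite: redirect Selects to renamed field symbols, then
    // re-infer types for the whole tree. Rebuilt nodes drop their types, so a
    // single infer() at the end restores a well-typed tree.
    def renameFields(tree: Node, renames: Map[TermSymbol, TermSymbol]): Node = {
      val rewritten = tree.replace({
        case Select(in, f) if renames.contains(f) => Select(in, renames(f))
      }, bottomUp = true)
      rewritten.infer()
    }
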
szeiger committed Jun 24, 2015
1 parent e26a7d0 commit f170783d830c97479ada2de00b34dde9593f737b
Showing with 339 additions and 202 deletions.
  1. +1 −0 common-test-resources/application.conf
  2. +1 −0 common-test-resources/logback.xml
  3. +1 −0 slick-testkit/src/doctest/resources/logback.xml
  4. +12 −10 slick-testkit/src/main/scala/com/typesafe/slick/testkit/tests/NewQuerySemanticsTest.scala
  5. +3 −0 slick/src/main/resources/reference.conf
  6. +1 −1 slick/src/main/scala/slick/ast/Node.scala
  7. +23 −17 slick/src/main/scala/slick/ast/Type.scala
  8. +48 −36 slick/src/main/scala/slick/ast/Util.scala
  9. +1 −0 slick/src/main/scala/slick/compiler/AssignUniqueSymbols.scala
  10. +0 −7 slick/src/main/scala/slick/compiler/CreateAggregates.scala
  11. +13 −13 slick/src/main/scala/slick/compiler/CreateResultSetMapping.scala
  12. +63 −41 slick/src/main/scala/slick/compiler/ExpandConditionals.scala
  13. +4 −5 slick/src/main/scala/slick/compiler/ExpandRecords.scala
  14. +1 −1 slick/src/main/scala/slick/compiler/ExpandSums.scala
  15. +4 −2 slick/src/main/scala/slick/compiler/ExpandTables.scala
  16. +13 −6 slick/src/main/scala/slick/compiler/FlattenProjections.scala
  17. +6 −9 slick/src/main/scala/slick/compiler/HoistClientOps.scala
  18. +1 −2 slick/src/main/scala/slick/compiler/InferTypes.scala
  19. +1 −1 slick/src/main/scala/slick/compiler/MergeToComprehensions.scala
  20. +5 −6 slick/src/main/scala/slick/compiler/PruneProjections.scala
  21. +28 −6 slick/src/main/scala/slick/compiler/QueryCompiler.scala
  22. +4 −6 slick/src/main/scala/slick/compiler/RemoveFieldNames.scala
  23. +18 −12 slick/src/main/scala/slick/compiler/ResolveZipJoins.scala
  24. +5 −5 slick/src/main/scala/slick/compiler/RewriteBooleans.scala
  25. +12 −9 slick/src/main/scala/slick/compiler/RewriteJoins.scala
  26. +1 −1 slick/src/main/scala/slick/compiler/SpecializeParameters.scala
  27. +50 −0 slick/src/main/scala/slick/compiler/VerifyTypes.scala
  28. +2 −1 slick/src/main/scala/slick/memory/MemoryQueryingProfile.scala
  29. +14 −5 slick/src/main/scala/slick/util/GlobalConfig.scala
  30. +3 −0 slick/src/main/scala/slick/util/Logging.scala
@@ -2,6 +2,7 @@ slick {
ansiDump = true
unicodeDump = true
sqlIndent = true
verifyTypes = true
}
tsql {
@@ -36,6 +36,7 @@
<logger name="slick.compiler.CodeGen" level="${log.qcomp.codeGen:-inherited}" />
<logger name="slick.compiler.RemoveFieldNames" level="${log.qcomp.removeFieldNames:-inherited}" />
<logger name="slick.compiler.InsertCompiler" level="${log.qcomp.insertCompiler:-inherited}" />
<logger name="slick.compiler.VerifyTypes" level="${log.qcomp.verifyTypes:-inherited}" />
<logger name="slick.jdbc.JdbcBackend.statement" level="${log.jdbc.statement:-info}" />
<logger name="slick.jdbc.JdbcBackend.benchmark" level="${log.jdbc.bench:-info}" />
<logger name="slick.jdbc.StatementInvoker.result" level="${log.jdbc.result:-info}" />
@@ -36,6 +36,7 @@
<logger name="slick.compiler.CodeGen" level="${log.qcomp.codeGen:-inherited}" />
<logger name="slick.compiler.RemoveFieldNames" level="${log.qcomp.removeFieldNames:-inherited}" />
<logger name="slick.compiler.InsertCompiler" level="${log.qcomp.insertCompiler:-inherited}" />
<logger name="slick.compiler.VerifyTypes" level="${log.qcomp.verifyTypes:-inherited}" />
<logger name="slick.jdbc.JdbcBackend.statement" level="${log.jdbc.statement:-info}" />
<logger name="slick.jdbc.JdbcBackend.benchmark" level="${log.jdbc.bench:-info}" />
<logger name="slick.jdbc.StatementInvoker.result" level="${log.jdbc.result:-info}" />
@@ -72,27 +72,29 @@ class NewQuerySemanticsTest extends AsyncTest[RelationalTestDB] {
val qa = for {
c <- coffees.take(3)
} yield (c.supID, (c.name, 42))
val qa2 = coffees.take(3).map(_.name).take(2)
val qb = qa.take(2).map(_._2)
val qb2 = qa.map(n => n).take(2).map(_._2)
val qc = qa.map(_._2).take(2)
val a1 = seq(
qa.result.map(_.toSet).map { ra =>
mark("qa", qa.result).map(_.toSet).map { ra =>
ra.size shouldBe 3
// No sorting, so result contents can vary
ra shouldAllMatch { case (s: Int, (i: String, 42)) => () }
},
qb.result.map(_.toSet).map { rb =>
mark("qa2", qa2.result).map(_.toSet).map(_.size shouldBe 2),
mark("qb", qb.result).map(_.toSet).map { rb =>
rb.size shouldBe 2
// No sorting, so result contents can vary
rb shouldAllMatch { case (i: String, 42) => () }
},
qb2.result.map(_.toSet).map { rb2 =>
mark("qb2", qb2.result).map(_.toSet).map { rb2 =>
rb2.size shouldBe 2
// No sorting, so result contents can vary
rb2 shouldAllMatch { case (i: String, 42) => () }
},
qc.result.map(_.toSet).map { rc =>
mark("qc", qc.result).map(_.toSet).map { rc =>
rc.size shouldBe 2
// No sorting, so result contents can vary
rc shouldAllMatch { case (i: String, 42) => () }
@@ -116,7 +118,7 @@ class NewQuerySemanticsTest extends AsyncTest[RelationalTestDB] {
} yield (c.name, s.city, c2.name)
def a2 = seq(
q0.result.named("Plain table").map(_.toSet).map { r0 =>
mark("q0", q0.result).named("q0: Plain table").map(_.toSet).map { r0 =>
r0 shouldBe Set(
("Colombian", 101, 799, 1, 0),
("French_Roast", 49, 799, 2, 0),
@@ -125,7 +127,7 @@ class NewQuerySemanticsTest extends AsyncTest[RelationalTestDB] {
("French_Roast_Decaf", 49, 999, 5, 0)
)
},
q1.result.named("Plain implicit join").map(_.toSet).map { r1 =>
mark("q1", q1.result).named("q1: Plain implicit join").map(_.toSet).map { r1 =>
r1 shouldBe Set(
(("Colombian","Groundsville:"),("Colombian",101,799,1,0),(101,"Acme, Inc.","99 Market Street"),799),
(("Colombian","Mendocino:"),("Colombian",101,799,1,0),(49,"Superior Coffee","1 Party Place"),799),
@@ -136,7 +138,7 @@ class NewQuerySemanticsTest extends AsyncTest[RelationalTestDB] {
)
},
ifCap(rcap.pagingNested) {
q1b.result.named("Explicit join with condition").map { r1b =>
mark("q1b", q1b.result).named("q1b: Explicit join with condition").map { r1b =>
r1b.toSet shouldBe Set(
("French_Roast","Mendocino","Colombian"),
("French_Roast","Mendocino","French_Roast"),
@@ -233,21 +235,21 @@ class NewQuerySemanticsTest extends AsyncTest[RelationalTestDB] {
val q6 = coffees.flatMap(c => suppliers)
def a4 = seq(
q5.result.map(_.toSet).map { r5 =>
mark("q5", q5.result).named("q5: Implicit self-join").map(_.toSet).map { r5 =>
r5 shouldBe Set(
(("Colombian",101,799,1,0),("Colombian",101,799,1,0)),
(("Colombian",101,799,1,0),("French_Roast",49,799,2,0)),
(("French_Roast",49,799,2,0),("Colombian",101,799,1,0)),
(("French_Roast",49,799,2,0),("French_Roast",49,799,2,0))
)
},
q5b.result.named("Explicit self-join with condition").map(_.toSet).map { r5b =>
mark("q5b", q5b.result).named("q5b: Explicit self-join with condition").map(_.toSet).map { r5b =>
r5b shouldBe Set(
(("Colombian",101,799,1,0),("Colombian",101,799,1,0)),
(("French_Roast",49,799,2,0),("French_Roast",49,799,2,0))
)
},
q6.result.named("Unused outer query result, unbound TableQuery").map(_.toSet).map { r6 =>
mark("q6", q6.result).named("q6: Unused outer query result, unbound TableQuery").map(_.toSet).map { r6 =>
r6 shouldBe Set(
(101,"Acme, Inc.","99 Market Street"),
(49,"Superior Coffee","1 Party Place"),
@@ -11,6 +11,9 @@ slick {
# Use multi-line, indented formatting for SQL statements
sqlIndent = false
# Verify types after each query compiler phase
verifyTypes = false
}
slick.driver.MySQL {
@@ -428,7 +428,7 @@ final case class Select(in: Node, field: TermSymbol) extends UnaryNode with Simp
type Self = Select
def child = in
override def childNames = Seq("in")
protected[this] def rebuild(child: Node) = copy(in = child) :@ nodeType
protected[this] def rebuild(child: Node) = copy(in = child)
override def getDumpInfo = Path.unapply(this) match {
case Some(l) => super.getDumpInfo.copy(name = "Path", mainInfo = l.reverseIterator.mkString("."))
case None => super.getDumpInfo
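
A hedged illustration of the change above (the `rewireSelect` helper is
hypothetical): a rebuilt `Select` no longer carries the old node type, so the
result has to be re-inferred before it is used in a typed tree.

    import slick.ast._

    // Hypothetical helper: replace the input of a typed Select. The copy does
    // not keep the old nodeType anymore, so infer() computes the new type from
    // the (already typed) new input.
    def rewireSelect(sel: Select, newIn: Node): Node =
      sel.copy(in = newIn).infer()
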
@@ -160,6 +160,11 @@ abstract class TypedCollectionTypeConstructor[C[_]](val classTag: ClassTag[C[_]]
.replaceFirst("^scala.collection.mutable.", "m.")
.replaceFirst("^scala.collection.generic.", "g.")
def createBuilder[E : ClassTag]: Builder[E, C[E]]
override def hashCode = classTag.hashCode() * 10
override def equals(o: Any) = o match {
case o: TypedCollectionTypeConstructor[_] => classTag == o.classTag
case _ => false
}
}
class ErasedCollectionTypeConstructor[C[_]](canBuildFrom: CanBuild[Any, C[Any]], classTag: ClassTag[C[_]]) extends TypedCollectionTypeConstructor[C](classTag) {
@@ -194,6 +199,11 @@ final class MappedScalaType(val baseType: Type, val mapper: MappedScalaType.Mapp
}
def children: Seq[Type] = Seq(baseType)
override def select(sym: TermSymbol) = baseType.select(sym)
override def hashCode = baseType.hashCode() + mapper.hashCode() + classTag.hashCode()
override def equals(o: Any) = o match {
case o: MappedScalaType => baseType == o.baseType && mapper == o.mapper && classTag == o.classTag
case _ => false
}
}
object MappedScalaType {
@@ -258,6 +268,8 @@ object TypedType {
}
class TypeUtil(val tpe: Type) extends AnyVal {
import TypeUtil.typeToTypeUtil
def asCollectionType: CollectionType = tpe match {
case c: CollectionType => c
case _ => throw new SlickException("Expected a collection type, found "+tpe)
@@ -275,33 +287,27 @@ class TypeUtil(val tpe: Type) extends AnyVal {
g(tpe)
}
@inline def replace(f: PartialFunction[Type, Type]): Type = TypeUtilOps.replace(tpe, f)
@inline def collect[T](pf: PartialFunction[Type, T]): Iterable[T] = TypeUtilOps.collect(tpe, pf)
@inline def collectAll[T](pf: PartialFunction[Type, Seq[T]]): Iterable[T] = collect[Seq[T]](pf).flatten
def replace(f: PartialFunction[Type, Type]): Type =
f.applyOrElse(tpe, { case t: Type => t.mapChildren(_.replace(f)) }: PartialFunction[Type, Type])
def collect[T](pf: PartialFunction[Type, T]): Iterable[T] = {
val b = new ArrayBuffer[T]
tpe.foreach(pf.andThen[Unit]{ case t => b += t }.orElse[Type, Unit]{ case _ => () })
b
}
def collectAll[T](pf: PartialFunction[Type, Seq[T]]): Iterable[T] = collect[Seq[T]](pf).flatten
}
object TypeUtil {
implicit def typeToTypeUtil(tpe: Type) = new TypeUtil(tpe)
implicit def typeToTypeUtil(tpe: Type): TypeUtil = new TypeUtil(tpe)
/** An extractor for node types */
object :@ {
def unapply(n: Node) = Some((n, n.nodeType))
}
}
object TypeUtilOps {
import TypeUtil.typeToTypeUtil
def replace(tpe: Type, f: PartialFunction[Type, Type]): Type =
f.applyOrElse(tpe, { case t: Type => t.mapChildren(_.replace(f)) }: PartialFunction[Type, Type])
def collect[T](tpe: Type, pf: PartialFunction[Type, T]): Iterable[T] = {
val b = new ArrayBuffer[T]
tpe.foreach(pf.andThen[Unit]{ case t => b += t }.orElse[Type, Unit]{ case _ => () })
b
}
}
trait SymbolScope {
def + (entry: (TermSymbol, Type)): SymbolScope
def get(sym: TermSymbol): Option[Type]
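
A minimal usage sketch of the `TypeUtil` methods and the `:@` extractor shown
above. Both helper functions are hypothetical, and the node passed to
`elementType` is assumed to already carry a type.

    import slick.ast._
    import slick.ast.TypeUtil._

    // Collect the TypeSymbols of all NominalTypes nested inside a Type:
    def nominalSymbols(tpe: Type): Iterable[TypeSymbol] =
      tpe.collect { case NominalType(ts, _) => ts }

    // Match a Node together with its assigned type via the `:@` extractor:
    def elementType(n: Node): Option[Type] = n match {
      case _ :@ CollectionType(_, el) => Some(el)
      case _ => None
    }
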
@@ -1,5 +1,7 @@
package slick.ast
import slick.ast.TypeUtil.:@
import scala.language.implicitConversions
import scala.collection.mutable.ArrayBuffer
@@ -26,13 +28,55 @@ object Util {
final class NodeOps(val tree: Node) extends AnyVal {
import Util._
@inline def collect[T](pf: PartialFunction[Node, T], stopOnMatch: Boolean = false): Seq[T] =
NodeOps.collect(tree, pf, stopOnMatch)
@inline def collect[T](pf: PartialFunction[Node, T], stopOnMatch: Boolean = false): Seq[T] = {
val b = new ArrayBuffer[T]
def f(n: Node): Unit = pf.andThen[Unit] { case t =>
b += t
if(!stopOnMatch) n.children.foreach(f)
}.orElse[Node, Unit]{ case _ =>
n.children.foreach(f)
}.apply(n)
f(tree)
b
}
def collectAll[T](pf: PartialFunction[Node, Seq[T]]): Seq[T] = collect[Seq[T]](pf).flatten
def replace(f: PartialFunction[Node, Node], keepType: Boolean = false, bottomUp: Boolean = false, retype: Boolean = false): Node =
NodeOps.replace(tree, f, keepType, bottomUp, retype)
def replace(f: PartialFunction[Node, Node], keepType: Boolean = false, bottomUp: Boolean = false): Node = {
def g(n: Node): Node = n.mapChildren(_.replace(f, keepType, bottomUp), keepType)
if(bottomUp) f.applyOrElse(g(tree), identity[Node]) else f.applyOrElse(tree, g)
}
/** Replace nodes in a bottom-up traversal with an extra state value that gets passed through the
* traversal. Types are never kept or rebuilt when a node changes.
*
* @param f The replacement function that takes the current Node (whose children have already
* been transformed), the current state, and the original (untransformed) version of
* the Node. */
def replaceFold[T](z: T)(f: PartialFunction[(Node, T, Node), (Node, T)]): (Node, T) = {
var v: T = z
val ch: IndexedSeq[Node] = tree.children.map { n =>
val (n2, v2) = n.replaceFold(v)(f)
v = v2
n2
}(collection.breakOut)
val t2 = tree.withChildren(ch)
f.applyOrElse((t2, v, tree), (t: (Node, T, Node)) => (t._1, t._2))
}
/** Replace nodes in a bottom-up traversal while invalidating TypeSymbols. Any later references
* to the invalidated TypeSymbols have their types unassigned, so that the whole tree can be
* retyped afterwards to get the correct new TypeSymbols in. */
def replaceInvalidate(f: PartialFunction[(Node, Set[TypeSymbol], Node), (Node, Set[TypeSymbol])]): Node = {
def containsTS(t: Type, invalid: Set[TypeSymbol]): Boolean = t match {
case NominalType(ts, exp) => invalid.contains(ts) || containsTS(exp, invalid)
case t => t.children.exists(ch => containsTS(ch, invalid))
}
replaceFold(Set.empty[TypeSymbol])(f.orElse {
case ((n: Ref), invalid, _) if containsTS(n.nodeType, invalid) => (n.untyped, invalid)
case ((n: Select), invalid, _) if containsTS(n.nodeType, invalid) => (n.untyped, invalid)
})._1
}
def foreach[U](f: (Node => U)): Unit = {
def g(n: Node) {
@@ -57,35 +101,3 @@ final class NodeOps(val tree: Node) extends AnyVal {
case (s, n) => Select(n, s)
}
}
object NodeOps {
import Util._
// These methods should be in the class but 2.10.0-RC1 took away the ability
// to use closures in value classes
def collect[T](tree: Node, pf: PartialFunction[Node, T], stopOnMatch: Boolean): Seq[T] = {
val b = new ArrayBuffer[T]
def f(n: Node): Unit = pf.andThen[Unit] { case t =>
b += t
if(!stopOnMatch) n.children.foreach(f)
}.orElse[Node, Unit]{ case _ =>
n.children.foreach(f)
}.apply(n)
f(tree)
b
}
def replace(tree: Node, f: PartialFunction[Node, Node], keepType: Boolean, bottomUp: Boolean, retype: Boolean): Node =
if(bottomUp) {
val t2 = tree.mapChildren(n => replace(n, f, keepType, bottomUp, retype), keepType && !retype)
val t3 = if(retype) t2.infer() else t2
f.applyOrElse(t2, identity[Node])
} else {
def g(n: Node) = {
val n2 = n.mapChildren(n => replace(n, f, keepType, bottomUp, retype), keepType && !retype)
if(retype) n2.infer() else n2
}
f.applyOrElse(tree, g)
}
}
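
Two hedged usage sketches for the new `replaceFold` and `replaceInvalidate`
methods defined above. Both helper functions are hypothetical; `Pure` and
`AnonTypeSymbol` are existing AST classes.

    import slick.ast._
    import slick.ast.Util._

    // replaceFold threads a state value through a bottom-up rebuild; here the
    // state simply counts how many nodes the partial function matched (all).
    def countNodes(tree: Node): Int =
      tree.replaceFold(0) { case (n, count, _) => (n, count + 1) }._2

    // replaceInvalidate: give every Pure node a fresh TypeSymbol (a made-up
    // rewrite) and record the old symbol as invalidated, so later references
    // to it lose their types and the final infer() can re-type the whole tree.
    def freshenPure(tree: Node): Node =
      tree.replaceInvalidate {
        case (Pure(v, ts), invalid, _) => (Pure(v, new AnonTypeSymbol), invalid + ts)
      }.infer()
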
@@ -52,6 +52,7 @@ class AssignUniqueSymbols extends Phase {
case a: AnonSymbol => defs.getOrElse(a, a)
case s => s
}
case n: Select => n.mapChildren(tr(_, replace)) :@ n.nodeType
case n => n.mapChildren(tr(_, replace))
}
}
@@ -39,12 +39,6 @@ class CreateAggregates extends Phase {
case Vector((s, n)) => Map(s -> List(s1))
case _ =>
val len = sources.length
// Join(1, Join(2, Join(3, Join(4, 5))))
// 1 -> s1._1
// 2 -> s1._2._1
// 3 -> s1._2._2._1
// 4 -> s1._2._2._2._1
// 5 -> s1._2._2._2._2
val it = Iterator.iterate(s1)(_ => ElementSymbol(2))
sources.zipWithIndex.map { case ((s, _), i) =>
val l = List.iterate(s1, i+1)(_ => ElementSymbol(2))
@@ -61,7 +55,6 @@ class CreateAggregates extends Phase {
n2
}
//FilteredQuery (Filter, SortBy, Take, Drop) -- GroupBy, Join, Union, Bind
case n => n
}
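
For context, a schematic sketch (not Slick's actual `QueryCompiler` code) of
how an optional verification phase could be interleaved after every regular
phase when `slick.verifyTypes` is enabled. The `withVerification` helper and
its parameters are made up for illustration.

    import slick.compiler.Phase

    // Schematic only: build the effective phase sequence. `verify` would come
    // from the config flag, `verifyPhase` would be the new VerifyTypes phase.
    def withVerification(phases: Vector[Phase], verify: Boolean,
                         verifyPhase: Phase): Vector[Phase] =
      if (!verify) phases else phases.flatMap(p => Vector(p, verifyPhase))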