Permalink
Browse files

Remove TypedNode:

TypedNodes always carried a type and didn’t lose it, even when
performing untyped transformations. Now that we have fully typed
transformations throughout the query compiler, this is not needed
anymore and `TypedNode` can be subsumed into `SimplyTypedNode`. A
TypedNode is now a SimplyTypedNode that infers a fixed type (specified
at construction time). This type is intrinsic to the node and will not
be lost when copying the node (just like any other constructor
parameter) but the real node type can be lost or reassigned normally.

- Add some extra `infer()` calls to instantiations of former TypedNodes
  throughout the query compiler.

- Add a special case for `NullaryNode` to `Node.mapChildren` to make
  `infer()` (and other calls) more efficient for `LiteralNode` (and
  other, less common NullaryNodes).
  • Loading branch information...
szeiger committed Jun 22, 2015
1 parent 1c2d5db commit e26a7d052af71d817b938717b07222be307d018c
@@ -92,27 +92,27 @@ class JoinTest extends AsyncTest[RelationalTestDB] {
_ <- ys ++= Seq((1, "a"), (2, "b"), (3, "b"), (4, "d"), (5, "d"))
// Left outer, lift primitive value
q1 = (xs.map(_.b) joinLeft ys.map(_.b) on (_ === _)).to[Set]
r1 <- q1.result
r1 <- mark("q1", q1.result)
r1t: Set[(String, Option[String])] = r1
_ = r1 shouldBe Set(("a",Some("a")), ("b",Some("b")), ("c",None))
// Nested left outer, lift primitive value
q2 = ((xs.map(_.b) joinLeft ys.map(_.b) on (_ === _)) joinLeft ys.map(_.b) on (_._1 === _)).to[Set]
r2 <- q2.result
r2 <- mark("q2", q2.result)
r2t: Set[((String, Option[String]), Option[String])] = r2
_ = r2 shouldBe Set((("a",Some("a")),Some("a")), (("b",Some("b")),Some("b")), (("c",None),None))
// Left outer, lift non-primitive value
q3 = (xs joinLeft ys on (_.b === _.b)).to[Set]
r3 <- q3.result
r3 <- mark("q3", q3.result)
r3t: Set[((Int, String), Option[(Int, String)])] = r3
_ = r3 shouldBe Set(((3,"b"),Some((3,"b"))), ((3,"b"),Some((2,"b"))), ((5,"c"),None), ((1,"a"),Some((1,"a"))), ((4,"c"),None), ((2,"b"),Some((3,"b"))), ((2,"b"),Some((2,"b"))))
// Left outer, lift non-primitive value, then map to primitive
q4 = (xs joinLeft ys on (_.b === _.b)).map { case (x, yo) => (x.a, yo.map(_.a)) }.to[Set]
r4 <- q4.result
r4 <- mark("q4", q4.result)
r4t: Set[(Int, Option[Int])] = r4
_ = r4 shouldBe Set((4,None), (3,Some(2)), (2,Some(3)), (2,Some(2)), (3,Some(3)), (1,Some(1)), (5,None))
// Nested left outer, lift non-primitive value
q5 = ((xs joinLeft ys on (_.b === _.b)) joinLeft ys on (_._1.b === _.b)).to[Set]
r5 <- q5.result
r5 <- mark("q5", q5.result)
r5t: Set[(((Int, String), Option[(Int, String)]), Option[(Int, String)])] = r5
_ = r5 shouldBe Set(
(((1,"a"),Some((1,"a"))),Some((1,"a"))),
@@ -129,18 +129,18 @@ class JoinTest extends AsyncTest[RelationalTestDB] {
)
// Right outer, lift primitive value
q6 = (ys.map(_.b) joinRight xs.map(_.b) on (_ === _)).to[Set]
r6 <- q6.result
r6 <- mark("q6", q6.result)
r6t: Set[(Option[String], String)] = r6
_ = r6 shouldBe Set((Some("a"),"a"), (Some("b"),"b"), (None,"c"))
// Nested right outer, lift primitive value
// (left-associative; not symmetrical to the nested left outer case)
q7 = ((ys.map(_.b) joinRight xs.map(_.b) on (_ === _)) joinRight xs.map(_.b) on (_._2 === _)).to[Set]
r7 <- q7.result
r7 <- mark("q7", q7.result)
rt: Set[(Option[(Option[String], String)], String)] = r7
_ = r7 shouldBe Set((Some((Some("a"),"a")),"a"), (Some((Some("b"),"b")),"b"), (Some((None,"c")),"c"))
// Right outer, lift non-primitive value
q8 = (ys joinRight xs on (_.b === _.b)).to[Set]
r8 <- q8.result
r8 <- mark("q8", q8.result)
r8t: Set[(Option[(Int, String)], (Int, String))] = r8
_ = r8 shouldBe Set(
(Some((1,"a")), (1,"a")),
@@ -153,13 +153,13 @@ class JoinTest extends AsyncTest[RelationalTestDB] {
)
// Right outer, lift non-primitive value, then map to primitive
q9 = (ys joinRight xs on (_.b === _.b)).map { case (yo, x) => (yo.map(_.a), x.a) }.to[Set]
r9 <- q9.result
r9 <- mark("q9", q9.result)
r9t: Set[(Option[Int], Int)] = r9
_ = r9 shouldBe Set((None,4), (Some(2),3), (Some(3),2), (Some(2),2), (Some(3),3), (Some(1),1), (None,5))
// Nested right outer, lift non-primitive value
// (left-associative; not symmetrical to the nested left outer case)
q10 = ((ys joinRight xs on (_.b === _.b)) joinRight xs on (_._1.map(_.b) === _.b)).to[Set]
r10 <- q10.result
r10 <- mark("q10", q10.result)
r10t: Set[(Option[(Option[(Int, String)], (Int, String))], (Int, String))] = r10
_ = r10 shouldBe Set(
(Some((Some((1,"a")),(1,"a"))),(1,"a")),
@@ -176,12 +176,12 @@ class JoinTest extends AsyncTest[RelationalTestDB] {
)
// Full outer, lift primitive values
q11 = (xs.map(_.b) joinFull ys.map(_.b) on (_ === _)).to[Set]
r11 <- q11.result
r11 <- mark("q11", q11.result)
r11t: Set[(Option[String], Option[String])] = r11
_ = r11 shouldBe Set((Some("a"),Some("a")), (Some("b"),Some("b")), (Some("c"),None), (None,Some("d")))
// Full outer, lift non-primitive values
q12 = (xs joinFull ys on (_.b === _.b)).to[Set]
r12 <- q12.result
r12 <- mark("q12", q12.result)
r12t: Set[(Option[(Int, String)], Option[(Int, String)])] = r12
_ = r12 shouldBe Set(
(Some((1,"a")),Some((1,"a"))),
@@ -124,10 +124,10 @@ class NestingTest extends AsyncTest[RelationalTestDB] {
val q4bt: Query[Rep[Option[Int]], _, Seq] = q4b
val t2 = seq(
q1b.result.named("q1b").map(_ shouldBe r.map(t => Some(t)).map(_.getOrElse((0, "", None: Option[String])))),
q2b.result.named("q2b").map(_ shouldBe r.map(t => Some(t._1)).map(_.get)),
q3b.result.named("q3b").map(_ shouldBe r.map(t => t._3).filter(_.isDefined).map(_.get)),
q4b.result.named("q4b").map(_ shouldBe r.map(t => Some(t._3)).map(_.getOrElse(None: Option[String])))
mark("q1b", q1b.result).map(_ shouldBe r.map(t => Some(t)).map(_.getOrElse((0, "", None: Option[String])))),
mark("q2b", q2b.result).map(_ shouldBe r.map(t => Some(t._1)).map(_.get)),
mark("q3b", q3b.result).map(_ shouldBe r.map(t => t._3).filter(_.isDefined).map(_.get)),
mark("a4b", q4b.result).map(_ shouldBe r.map(t => Some(t._3)).map(_.getOrElse(None: Option[String])))
)
// Unpack result types
@@ -24,11 +24,11 @@ class PagingTest extends AsyncTest[RelationalTestDB] {
_ <- ids ++= (1 to 10)
_ <- mark("q1", q1.result).map(_ shouldBe (1 to 10).toList)
_ <- mark("q2", q2.result).map(_ shouldBe (1 to 5).toList)
_ <- ifCap(rcap.pagingDrop)(seq(
q3.result.map(_ shouldBe (6 to 10).toList),
q4.result.map(_ shouldBe (6 to 8).toList),
q5.result.map(_ shouldBe (4 to 5).toList)
))
_ <- ifCap(rcap.pagingDrop)(for {
_ <- mark("q3", q3.result).map(_ shouldBe (6 to 10).toList)
_ <- mark("q4", q4.result).map(_ shouldBe (6 to 8).toList)
_ <- mark("q5", q5.result).map(_ shouldBe (4 to 5).toList)
} yield ())
_ <- mark("q6", q6.result).map(_ shouldBe Nil)
} yield ()
}
@@ -82,9 +82,9 @@ final case class Comprehension(sym: TermSymbol, from: Node, select: Option[Node]
}
/** The row_number window function */
final case class RowNumber(by: Seq[(Node, Ordering)] = Seq.empty) extends TypedNode {
final case class RowNumber(by: Seq[(Node, Ordering)] = Seq.empty) extends SimplyTypedNode {
type Self = RowNumber
def tpe = ScalaBaseType.longType
def buildType = ScalaBaseType.longType
lazy val children = by.map(_._1)
protected[this] def rebuild(ch: IndexedSeq[Node]) =
copy(by = by.zip(ch).map{ case ((_, o), n) => (n, o) })
@@ -18,7 +18,7 @@ final case class Insert(tableSym: TermSymbol, table: Node, linear: Node) extends
}
/** A column in an Insert operation. */
final case class InsertColumn(children: IndexedSeq[Node], fs: FieldSymbol, tpe: Type) extends Node with TypedNode {
final case class InsertColumn(children: IndexedSeq[Node], fs: FieldSymbol, buildType: Type) extends Node with SimplyTypedNode {
type Self = InsertColumn
protected[this] def rebuild(ch: IndexedSeq[Node]) = copy(children = ch)
override def getDumpInfo = super.getDumpInfo.copy(mainInfo = fs.toString)
@@ -103,10 +103,10 @@ class FunctionSymbol(val name: String) extends TermSymbol {
}
/** Create a typed Apply of this Symbol */
def typed(tpe: Type, ch: Node*): Apply with TypedNode = Apply(this, ch)(tpe)
def typed(tpe: Type, ch: Node*): Apply = Apply(this, ch)(tpe)
/** Create a typed Apply of this Symbol */
def typed[T : ScalaBaseType](ch: Node*): Apply with TypedNode = Apply(this, ch)(implicitly[ScalaBaseType[T]])
def typed[T : ScalaBaseType](ch: Node*): Apply = Apply(this, ch)(implicitly[ScalaBaseType[T]])
override def toString = "Function "+name
}
@@ -7,11 +7,8 @@ import slick.util.{Logging, Dumpable, DumpInfo, GlobalConfig}
import Util._
import TypeUtil._
/**
* A node in the Slick AST.
*
* Every Node has a number of child nodes and an optional type annotation.
*/
/** A node in the Slick AST.
* Every Node has a number of child nodes and an optional type annotation. */
trait Node extends Dumpable {
type Self >: this.type <: Node
@@ -37,7 +34,7 @@ trait Node extends Dumpable {
/** Apply a mapping function to all children of this node and recreate the node with the new
* children. If all new children are identical to the old ones, this node is returned. If
* ``keepType`` is true, the type of this node is kept even when the children have changed. */
final def mapChildren(f: Node => Node, keepType: Boolean = false): Self = {
final def mapChildren(f: Node => Node, keepType: Boolean = false): Self = if(isInstanceOf[NullaryNode]) this else {
val n: Self = mapOrNone(children)(f).map(rebuild).getOrElse(this)
if(_type == UnassignedType || !keepType) n else (n :@ _type).asInstanceOf[Self]
}
@@ -59,9 +56,9 @@ trait Node extends Dumpable {
if(seenType || _type != UnassignedType) rebuild(children.toIndexedSeq) else this
/** Return this Node with a Type assigned (if no other type has been seen for it yet) or a typed copy. */
final def :@ (tpe: Type): Self = {
val n: Self = if(seenType && tpe != _type) rebuild(children.toIndexedSeq) else this
n._type = tpe
final def :@ (newType: Type): Self = {
val n: Self = if(seenType && newType != _type) rebuild(children.toIndexedSeq) else this
n._type = newType
n
}
@@ -83,20 +80,20 @@ trait Node extends Dumpable {
(n, args)
case _ => (super.toString, "")
}
val tpe = peekType
val t = peekType
val ch = this match {
// Omit path details unless dumpPaths is set
case Path(l @ (_ :: _ :: _)) if !GlobalConfig.dumpPaths => Vector.empty
case _ => childNames.zip(children).toVector
}
DumpInfo(objName, mainInfo, if(tpe != UnassignedType) ": " + tpe.toString else "", ch)
DumpInfo(objName, mainInfo, if(t != UnassignedType) ": " + t.toString else "", ch)
}
override final def toString = getDumpInfo.getNamePlusMainInfo
}
/** A Node whose children can be typed independently of each other and which can be typed without
* access to its scope. */
/** A Node which can be typed without access to its scope, and whose children can be typed
* independently of each other. */
trait SimplyTypedNode extends Node {
type Self >: this.type <: SimplyTypedNode
@@ -108,22 +105,6 @@ trait SimplyTypedNode extends Node {
}
}
/** A Node with a fixed type that cannot be overridden or removed. */
trait TypedNode extends Node {
def tpe: Type
override def nodeType: Type = {
val t = super.nodeType
if(t eq UnassignedType) tpe else t
}
def withInferredType(scope: SymbolScope, typeChildren: Boolean, retype: Boolean): Self =
mapChildren(_.infer(scope, typeChildren, retype), !retype)
override def hasType = (tpe ne UnassignedType) || super.hasType
override protected[this] def peekType: Type = super.peekType match {
case UnassignedType => tpe
case t => t
}
}
/** An expression that represents a conjunction of expressions. */
final case class ProductNode(children: Seq[Node]) extends SimplyTypedNode {
type Self = ProductNode
@@ -171,14 +152,14 @@ final case class StructNode(elements: IndexedSeq[(TermSymbol, Node)]) extends Si
* contains user-generated data or may change in future executions of what
* is otherwise the same query. A database back-end should usually turn
* volatile constants into bind variables. */
class LiteralNode(val tpe: Type, val value: Any, val volatileHint: Boolean = false) extends NullaryNode with TypedNode {
class LiteralNode(val buildType: Type, val value: Any, val volatileHint: Boolean = false) extends NullaryNode with SimplyTypedNode {
type Self = LiteralNode
override def getDumpInfo = super.getDumpInfo.copy(name = "LiteralNode", mainInfo = s"$value (volatileHint=$volatileHint)")
protected[this] def rebuild = new LiteralNode(tpe, value, volatileHint)
protected[this] def rebuild = new LiteralNode(buildType, value, volatileHint)
override def hashCode = tpe.hashCode() + (if(value == null) 0 else value.asInstanceOf[AnyRef].hashCode)
override def hashCode = buildType.hashCode() + (if(value == null) 0 else value.asInstanceOf[AnyRef].hashCode)
override def equals(o: Any) = o match {
case l: LiteralNode => tpe == l.tpe && value == l.value
case l: LiteralNode => buildType == l.buildType && value == l.value
case _ => false
}
}
@@ -456,9 +437,9 @@ final case class Select(in: Node, field: TermSymbol) extends UnaryNode with Simp
}
/** A function call expression. */
final case class Apply(sym: TermSymbol, children: Seq[Node])(val tpe: Type) extends TypedNode {
final case class Apply(sym: TermSymbol, children: Seq[Node])(val buildType: Type) extends SimplyTypedNode {
type Self = Apply
protected[this] def rebuild(ch: IndexedSeq[slick.ast.Node]) = copy(children = ch)(tpe)
protected[this] def rebuild(ch: IndexedSeq[slick.ast.Node]) = copy(children = ch)(buildType)
override def getDumpInfo = super.getDumpInfo.copy(mainInfo = sym.toString)
}
@@ -503,27 +484,27 @@ object FwdPath {
}
/** A Node representing a database table. */
final case class TableNode(schemaName: Option[String], tableName: String, identity: TableIdentitySymbol, driverTable: Any, baseIdentity: TableIdentitySymbol) extends NullaryNode with TypedNode {
final case class TableNode(schemaName: Option[String], tableName: String, identity: TableIdentitySymbol, driverTable: Any, baseIdentity: TableIdentitySymbol) extends NullaryNode with SimplyTypedNode {
type Self = TableNode
def tpe = CollectionType(TypedCollectionTypeConstructor.seq, NominalType(identity, UnassignedType))
def buildType = CollectionType(TypedCollectionTypeConstructor.seq, NominalType(identity, UnassignedType))
def rebuild = copy()
override def getDumpInfo = super.getDumpInfo.copy(name = "Table", mainInfo = schemaName.map(_ + ".").getOrElse("") + tableName)
}
/** A node that represents an SQL sequence. */
final case class SequenceNode(name: String)(val increment: Long) extends NullaryNode with TypedNode {
final case class SequenceNode(name: String)(val increment: Long) extends NullaryNode with SimplyTypedNode {
type Self = SequenceNode
def tpe = ScalaBaseType.longType
def buildType = ScalaBaseType.longType
def rebuild = copy()(increment)
}
/** A Query of this special Node represents an infinite stream of consecutive
* numbers starting at the given number. This is used as an operand for
* zipWithIndex. It is not exposed directly in the query language because it
* cannot be represented in SQL outside of a 'zip' operation. */
final case class RangeFrom(start: Long = 1L) extends NullaryNode with TypedNode {
final case class RangeFrom(start: Long = 1L) extends NullaryNode with SimplyTypedNode {
type Self = RangeFrom
def tpe = CollectionType(TypedCollectionTypeConstructor.seq, ScalaBaseType.longType)
def buildType = CollectionType(TypedCollectionTypeConstructor.seq, ScalaBaseType.longType)
def rebuild = copy()
}
@@ -600,7 +581,7 @@ final case class GetOrElse(child: Node, default: () => Any) extends UnaryNode wi
/** A compiled statement with a fixed type, a statement string and
* driver-specific extra data. */
final case class CompiledStatement(statement: String, extra: Any, tpe: Type) extends NullaryNode with TypedNode {
final case class CompiledStatement(statement: String, extra: Any, buildType: Type) extends NullaryNode with SimplyTypedNode {
type Self = CompiledStatement
def rebuild = copy()
override def getDumpInfo =
@@ -625,7 +606,7 @@ final case class RebuildOption(discriminator: Node, data: Node) extends BinaryNo
}
/** A parameter from a QueryTemplate which gets turned into a bind variable. */
final case class QueryParameter(extractor: (Any => Any), tpe: Type) extends NullaryNode with TypedNode {
final case class QueryParameter(extractor: (Any => Any), buildType: Type) extends NullaryNode with SimplyTypedNode {
type Self = QueryParameter
def rebuild = copy()
override def getDumpInfo = super.getDumpInfo.copy(mainInfo = extractor + "@" + System.identityHashCode(extractor))
@@ -638,7 +619,7 @@ object QueryParameter {
* on two primitive values. The given Nodes must also be of type `LiteralNode` or
* `QueryParameter`. */
def constOp[T](name: String)(op: (T, T) => T)(l: Node, r: Node)(implicit tpe: ScalaBaseType[T]): Node = (l, r) match {
case (LiteralNode(lv) :@ (lt: TypedType[_]), LiteralNode(rv) :@ (rt: TypedType[_])) if lt.scalaType == tpe && rt.scalaType == tpe => LiteralNode[T](op(lv.asInstanceOf[T], rv.asInstanceOf[T]))
case (LiteralNode(lv) :@ (lt: TypedType[_]), LiteralNode(rv) :@ (rt: TypedType[_])) if lt.scalaType == tpe && rt.scalaType == tpe => LiteralNode[T](op(lv.asInstanceOf[T], rv.asInstanceOf[T])).infer()
case (LiteralNode(lv) :@ (lt: TypedType[_]), QueryParameter(re, rt: TypedType[_])) if lt.scalaType == tpe && rt.scalaType == tpe =>
QueryParameter(new (Any => T) {
def apply(param: Any) = op(lv.asInstanceOf[T], re(param).asInstanceOf[T])
@@ -14,15 +14,15 @@ class ExpandConditionals extends Phase {
def tr(n: Node): Node = n.mapChildren(tr, keepType = true) match {
// Expand multi-column SilentCasts
case cast @ Library.SilentCast(ch) :@ Type.Structural(ProductType(typeCh)) =>
val elems = typeCh.zipWithIndex.map { case (t, idx) => tr(Library.SilentCast.typed(t, ch.select(ElementSymbol(idx+1)).infer())) }
val elems = typeCh.zipWithIndex.map { case (t, idx) => tr(Library.SilentCast.typed(t, ch.select(ElementSymbol(idx+1)).infer()).infer()) }
ProductNode(elems).infer()
case Library.SilentCast(ch) :@ Type.Structural(StructType(typeCh)) =>
val elems = typeCh.map { case (sym, t) => (sym, tr(Library.SilentCast.typed(t, ch.select(sym).infer()))) }
val elems = typeCh.map { case (sym, t) => (sym, tr(Library.SilentCast.typed(t, ch.select(sym).infer()).infer())) }
StructNode(elems).infer()
// Optimize trivial SilentCasts
case Library.SilentCast(v :@ tpe) :@ tpe2 if tpe.structural == tpe2.structural => v
case Library.SilentCast(Library.SilentCast(ch)) :@ tpe => tr(Library.SilentCast.typed(tpe, ch))
case Library.SilentCast(Library.SilentCast(ch)) :@ tpe => tr(Library.SilentCast.typed(tpe, ch).infer())
case Library.SilentCast(LiteralNode(None)) :@ (tpe @ OptionType.Primitive(_)) => LiteralNode(tpe, None)
// Expand multi-column IfThenElse
Oops, something went wrong.

0 comments on commit e26a7d0

Please sign in to comment.