Permalink
Browse files

Avoid subqueries around zipWithIndex operations

- Eliminate or move subquery boundaries in `reorderOperations` to
  allow more operations to be fused.

- Use a top-down transformation instead of bottom-up in `removeTakeDrop`
  to allow fusion of nested Take and Drop operations, as originally
  intended.

- Use the correct type `Long` instead of `Int` for the indexes produced
  by `zipWithIndex` in QueryInterpreter.

These changes are covered by tests in JoinTest.testZip,
PagingTest.testRawPagination and NewQuerySemanticsTest.testNewFusion.
  • Loading branch information...
szeiger committed Aug 11, 2015
1 parent 5ca0e5c commit 764b405fc089e610663ac96bff0c5e59ac9a3b97
@@ -226,21 +226,25 @@ class JoinTest extends AsyncTest[RelationalTestDB] {
} yield (c.id, p.category)
_ <- mark("q2", q2.result).map(_ shouldBe List((1,-1), (2,1), (3,2), (4,2)))
q3 = for {
(c, p) <- categories zip posts
(c, p) <- categories.sortBy(_.id) zip posts.sortBy(_.id)
} yield (c.id, p.category)
_ <- mark("q3", q3.result).map(_ shouldBe List((1, -1), (3, 1), (2, 2), (4, 3)))
_ <- mark("q3", q3.result).map(_ shouldBe List((1, -1), (2, 1), (3, 2), (4, 3)))
q4 = for {
res <- categories.zipWith(posts, (c: Categories, p: Posts) => (c.id, p.category))
res <- categories.sortBy(_.id).zipWith(posts.sortBy(_.id), (c: Categories, p: Posts) => (c.id, p.category))
} yield res
_ <- mark("q4", q4.result).map(_ shouldBe List((1, -1), (3, 1), (2, 2), (4, 3)))
_ <- mark("q4", q4.result).map(_ shouldBe List((1, -1), (2, 1), (3, 2), (4, 3)))
q5 = for {
(c, i) <- categories.sortBy(_.id).zipWithIndex
} yield (c.id, i)
_ <- mark("q5", q5.result).map(_ shouldBe List((1,0), (2,1), (3,2), (4,3)))
q5b = for {
(c, i) <- categories.zipWithIndex
} yield (c.id, i)
_ <- mark("q5", q5.result).map(_ shouldBe List((1,0), (3,1), (2,2), (4,3)))
_ <- mark("q5b", q5b.result).map(_.map(_._2).toSet shouldBe Set(0L, 1L, 2L, 3L))
q6 = for {
((c, p), i) <- (categories zip posts).zipWithIndex
((c, p), i) <- (categories.sortBy(_.id) zip posts.sortBy(_.id)).zipWithIndex
} yield (c.id, p.category, i)
_ <- mark("q6", q6.result).map(_ shouldBe List((1, -1, 0), (3, 1, 1), (2, 2, 2), (4, 3, 3)))
_ <- mark("q6", q6.result).map(_ shouldBe List((1, -1, 0), (2, 1, 1), (3, 2, 2), (4, 3, 3)))
} yield ()
}
@@ -502,6 +502,7 @@ class NewQuerySemanticsTest extends AsyncTest[RelationalTestDB] {
val q14 = q13.to[Set]
val q15 = (as.map(a => a.id.?).filter(_ < 2) unionAll as.map(a => a.id.?).filter(_ > 2)).map(_.get).to[Set]
val q16 = (as.map(a => a.id.?).filter(_ < 2) unionAll as.map(a => a.id.?).filter(_ > 2)).map(_.getOrElse(-1)).to[Set].filter(_ =!= 42)
val q17 = as.sortBy(_.id).zipWithIndex.filter(_._2 < 2L).map { case (a, i) => (a.id, i) }
if(tdb.driver == H2Driver) {
assertNesting(q1, 1)
@@ -528,6 +529,7 @@ class NewQuerySemanticsTest extends AsyncTest[RelationalTestDB] {
assertNesting(q14, 2)
assertNesting(q15, 2)
assertNesting(q16, 2)
assertNesting(q17, 2)
}
for {
@@ -558,6 +560,7 @@ class NewQuerySemanticsTest extends AsyncTest[RelationalTestDB] {
_ <- mark("q14", q14.result).map(_ shouldBe Set(1, 3))
_ <- mark("q15", q15.result).map(_ shouldBe Set(1, 3))
_ <- mark("q16", q16.result).map(_ shouldBe Set(1, 3))
_ <- ifCap(rcap.zip)(mark("q17", q17.result).map(_ shouldBe Seq((1,0), (2,1))))
} yield ()
}
@@ -19,6 +19,7 @@ class PagingTest extends AsyncTest[RelationalTestDB] {
val q4b = q1.drop(5).take(3).sortBy(_.id)
def q5 = q1 take 5 drop 3
val q6 = q1 take 0
val q7 = ids.filter(_.id > 3).sortBy(_.id).take(3)
for {
_ <- ids.schema.create
@@ -32,6 +33,7 @@ class PagingTest extends AsyncTest[RelationalTestDB] {
_ <- mark("q5", q5.result).map(_ shouldBe (4 to 5).toList)
} yield ())
_ <- mark("q6", q6.result).map(_ shouldBe Nil)
_ <- mark("q7", q7.result).map(_ shouldBe List(4, 5, 6))
} yield ()
}
@@ -222,8 +222,14 @@ object Subquery {
/** The kind of boundary a `Subquery` node represents. Each case object below is one such kind:
  * either a rule for when the subquery is created, or the position of the boundary relative
  * to the mapping operation that adds a row number. */
sealed trait Condition
/** Always create a subquery */
case object Always extends Condition
/** Create a subquery if the current Comprehension contains a GROUP BY, ORDER BY or HAVING clause */
case object AboveGroupBy extends Condition
/** A Subquery boundary below the mapping operation that adds a ROWNUM */
case object BelowRownum extends Condition
/** A Subquery boundary above the mapping operation that adds a ROWNUM */
case object AboveRownum extends Condition
/** A Subquery boundary below the mapping operation that adds a ROW_NUMBER */
case object BelowRowNumber extends Condition
/** A Subquery boundary above the mapping operation that adds a ROW_NUMBER */
case object AboveRowNumber extends Condition
}
/** Common superclass for expressions of type (CollectionType(c, t), _) => CollectionType(c, t). */
@@ -296,6 +296,13 @@ class TypeUtil(val tpe: Type) extends AnyVal {
}
def collectAll[T](pf: PartialFunction[Type, Seq[T]]): Iterable[T] = collect[Seq[T]](pf).flatten
/** Check whether this type mentions any of the given type symbols.
  *
  * A `NominalType` matches when its own symbol is in `tss` or when its second
  * component contains one; any other type matches when one of its children does.
  * An empty `tss` yields `false` without inspecting the type.
  *
  * @param tss the type symbols to look for
  * @return `true` if at least one symbol from `tss` occurs in `tpe`
  */
def containsSymbol(tss: scala.collection.Set[TypeSymbol]): Boolean =
  tss.nonEmpty && (tpe match {
    case NominalType(sym, expansion) =>
      // Test the symbol itself first, then descend into the wrapped type
      tss.contains(sym) || expansion.containsSymbol(tss)
    case other =>
      other.children.exists(child => child.containsSymbol(tss))
  })
}
object TypeUtil {
@@ -13,10 +13,11 @@ class RemoveTakeDrop extends Phase {
val name = "removeTakeDrop"
def apply(state: CompilerState) = state.map { n =>
val n2 = n.replaceInvalidate {
case (n @ TakeDrop(from, t, d), invalid, _) =>
val invalid = mutable.Set[TypeSymbol]()
def tr(n: Node): Node = n.replace {
case n @ TakeDrop(from, t, d) =>
logger.debug(s"""Translating "drop $d, then take $t" to zipWithIndex operation:""", n)
val fromRetyped = from.infer()
val fromRetyped = tr(from).infer()
val from2 = fromRetyped match {
case b: Bind => b
case n =>
@@ -41,8 +42,13 @@ class RemoveTakeDrop extends Phase {
logger.debug(s"""Translated "drop $d, then take $t" to zipWithIndex operation:""", b2)
val invalidate = fromRetyped.nodeType.collect { case NominalType(ts, _) => ts }
logger.debug("Invalidating TypeSymbols: "+invalidate.mkString(", "))
(b2, invalid ++ invalidate)
invalid ++= invalidate
b2
case (n: Ref) if n.nodeType.containsSymbol(invalid) => n.untyped
case n @ Select(in, f) if n.nodeType.containsSymbol(invalid) => Select(tr(in), f)
}
val n2 = tr(n)
logger.debug("After removeTakeDrop without inferring:", n2)
n2.infer()
}
@@ -11,7 +11,9 @@ class ReorderOperations extends Phase {
def apply(state: CompilerState) = state.map(convert)
def convert(tree: Node): Node = tree.replace({
def convert(tree: Node): Node = tree.replace({ case n => convert1(n) }, keepType = true, bottomUp = true)
def convert1(tree: Node): Node = tree match {
// Push Bind into Union
case n @ Bind(s1, Union(l1, r1, all), sel) =>
logger.debug("Pushing Bind into both sides of a Union", Ellipsis(n, List(0, 0), List(0, 1)))
@@ -41,6 +43,48 @@ class ReorderOperations extends Phase {
logger.debug("Pushed CollectionCast into both sides of a Union", Ellipsis(n2, List(0, 0), List(1, 0)))
n2
// Remove Subquery boundary on top of TableNode and Join
case Subquery(n @ (_: TableNode | _: Join), _) => n
// Push aliasing / literal projection into Subquery
case n @ Bind(s, Subquery(from, cond), Pure(StructNode(defs), ts1)) if isAliasingOrLiteral(s, defs) =>
Subquery(n.copy(from = from), cond).infer()
// If a Filter checks an upper bound of a ROWNUM, push it into the AboveRownum boundary
case filter @ Filter(s1,
sq @ Subquery(bind @ Bind(bs1, from1, Pure(StructNode(defs1), ts1)), Subquery.AboveRownum),
Apply(Library.<= | Library.<, Seq(Select(Ref(rs), f1), v1)))
if rs == s1 && defs1.find {
case (f, n) if f == f1 => isRownumCalculation(n)
case _ => false
}.isDefined =>
sq.copy(child = filter.copy(from = bind)).infer()
// Push a BelowRowNumber boundary into SortBy
case sq @ Subquery(n: SortBy, Subquery.BelowRowNumber) =>
n.copy(from = convert1(sq.copy(child = n.from))).infer()
// Push a BelowRowNumber boundary into Filter
case sq @ Subquery(n: Filter, Subquery.BelowRowNumber) =>
n.copy(from = convert1(sq.copy(child = n.from))).infer()
case n => n
}, keepType = true, bottomUp = true)
}
/** Decide whether every definition in `defs` is either a forward path whose first
  * symbol is `base`, a `LiteralNode`, or a `QueryParameter`. The verdict is also
  * written to the debug log.
  *
  * @param base the symbol that path-based definitions must start from
  * @param defs the (symbol, expression) pairs to inspect; only the expressions are examined
  * @return `true` if all expressions qualify (vacuously `true` for empty `defs`)
  */
def isAliasingOrLiteral(base: TermSymbol, defs: IndexedSeq[(TermSymbol, Node)]) = {
  def qualifies(expr: Node): Boolean = expr match {
    case FwdPath(head :: _) if head == base => true
    case _: LiteralNode | _: QueryParameter => true
    case _ => false
  }
  val verdict = defs.forall { case (_, expr) => qualifies(expr) }
  logger.debug("Bind from "+base+" is aliasing / literal: "+verdict)
  verdict
}
/** Tell whether `n` is a `RowNumber` node, or a `Library.+` / `Library.-` application
  * in which at least one operand (recursively) is such a calculation.
  *
  * @param n the expression to inspect
  * @return `true` if a `RowNumber` is reachable through only `+` / `-` applications
  */
def isRownumCalculation(n: Node): Boolean = n match {
  case _: RowNumber => true
  case Apply(Library.+ | Library.-, operands) =>
    operands.exists(operand => isRownumCalculation(operand))
  case _ => false
}
}
@@ -16,6 +16,9 @@ class ResolveZipJoins(rownumStyle: Boolean = false) extends Phase {
type State = Boolean
val name = "resolveZipJoins"
val condAbove: Subquery.Condition = if(rownumStyle) Subquery.AboveRownum else Subquery.AboveRowNumber
val condBelow: Subquery.Condition = if(rownumStyle) Subquery.BelowRownum else Subquery.BelowRowNumber
def apply(state: CompilerState) = {
val n2 = state.tree.replace({
case b @ Bind(s1,
@@ -47,8 +50,8 @@ class ResolveZipJoins(rownumStyle: Boolean = false) extends Phase {
val idxExpr =
if(offset == 1L) RowNumber()
else Library.-.typed[Long](RowNumber(), LiteralNode(1L - offset))
val lbind = Bind(ls, Subquery(from, Subquery.Always), Pure(StructNode(defs :+ (idxSym, idxExpr))))
Bind(s1, Subquery(lbind, Subquery.Always), p.replace {
val lbind = Bind(ls, Subquery(from, condBelow), Pure(StructNode(defs :+ (idxSym, idxExpr))))
Bind(s1, Subquery(lbind, condAbove), p.replace {
case Select(Ref(s), ElementSymbol(1)) if s == s1 => Ref(s1)
case Select(Ref(s), ElementSymbol(2)) if s == s1 => Select(Ref(s1), idxSym)
}).infer()
@@ -391,9 +391,9 @@ trait JdbcStatementBuilderComponent { driver: JdbcDriver =>
}
b" end)"
case RowNumber(by) =>
b"row_number() over("
if(by.isEmpty) b"order by (select 1)"
else buildOrderByClause(by)
b"row_number() over(order by "
if(by.isEmpty) b"(select 1)"
else b.sep(by, ", "){ case (n, o) => buildOrdering(n, o) }
b")"
case p @ Path(path) =>
val (base, rest) = path.foldRight[(Option[TermSymbol], List[TermSymbol])]((None, Nil)) {
@@ -64,7 +64,7 @@ class QueryInterpreter(db: HeapBackend#Database, params: Any) extends Logging {
b.result()
case Join(_, _, left, RangeFrom(0), JoinType.Zip, LiteralNode(true)) =>
val leftV = run(left).asInstanceOf[Coll]
leftV.zipWithIndex.map { case (l, r) => new ProductValue(Vector(l, r)) }
leftV.zipWithIndex.map { case (l, r) => new ProductValue(Vector(l, r.toLong)) }
case Join(_, _, left, right, JoinType.Zip, LiteralNode(true)) =>
val leftV = run(left).asInstanceOf[Coll]
val rightV = run(right).asInstanceOf[Coll]

0 comments on commit 764b405

Please sign in to comment.