Permalink
Browse files

Prevent repetition of bind variables used in GROUP BY:

Any expression from a GROUP BY clause may be inlined into later clauses
(like SELECT or ORDER BY), which can lead to duplication of expressions.
This is usually not a problem; in fact, SQL is designed to allow reuse
of GROUP BY expressions without aggregation functions in these clauses.
This change prevents such reuse when a GROUP BY expression contains a
bind variable because the database could no longer know that the
expressions are indeed the same. If a leaked GROUP BY key is detected,
we push the source of the GroupBy into a subquery.

Test in AggregateTest.testFusedGroupBy. Fixes #1282.
  • Loading branch information...
szeiger committed Sep 23, 2015
1 parent 97b35a8 commit 8f8ae8bad12b5e5d3248e5ec12e4be85ee224b15
@@ -282,6 +282,32 @@ class AggregateTest extends AsyncTest[RelationalTestDB] {
} yield ()
}
/** Regression test for GROUP BY keys that contain a bind variable.
  *
  * Builds three queries over a small two-column table, each grouping by an
  * expression (or tuple of expressions) and then returning part of the key:
  *  - q1 returns a key that is `value + <bound literal 1>`;
  *  - q2 returns the key component that contains the bound literal;
  *  - q3 returns the key component that is the plain `value` column.
  *
  * According to the commit description, returning a key that contains a bind
  * variable should cause the GroupBy source to be pushed into a subquery, which
  * is why q1 and q2 are expected to nest deeper than q3.
  */
def testFusedGroupBy = {
// Test table with an Int primary key `id` and an Int `value` column.
class A(tag: Tag) extends Table[(Int, Int)](tag, "A_FUSEDGROUPBY") {
def id = column[Int]("id", O.PrimaryKey)
def value = column[Int]("value")
def * = (id, value)
}
val as = TableQuery[A]
// Group by `value + 1` where the 1 is a bound literal, then return the key itself.
val q1 = as.map(t => t.value + LiteralColumn(1).bind).groupBy(identity).map(_._1)
// Group by the pair (value, value + bound 1); return the second key component (contains the bind).
val q2 = as.map(t => (t.value, t.value + LiteralColumn(1).bind)).groupBy(identity).map(_._1._2)
// Same grouping key as q2, but return only the first key component (no bind variable).
val q3 = as.map(t => (t.value, t.value + LiteralColumn(1).bind)).groupBy(identity).map(_._1._1)
// The nesting expectations are only checked when running against H2.
// NOTE(review): assertNesting presumably counts nested comprehension/SELECT levels in the
// compiled query -- confirm against its definition, which is not visible here.
if(tdb.driver == H2Driver) {
assertNesting(q1, 2)
assertNesting(q2, 2)
assertNesting(q3, 1)
}
DBIO.seq(
as.schema.create,
// Three rows with two distinct values (10 and 20), so each query should yield two groups.
as ++= Seq((1, 10), (2, 20), (3, 20)),
// q1 and q2 return value + 1 per group; q3 returns the plain value per group.
mark("q1", q1.result).map(_.toSet shouldBe Set(11, 21)),
mark("q2", q2.result).map(_.toSet shouldBe Set(11, 21)),
mark("q3", q3.result).map(_.toSet shouldBe Set(10, 20))
)
}
def testDistinct = {
class A(tag: Tag) extends Table[String](tag, "A_DISTINCT") {
def id = column[Int]("id", O.PrimaryKey)
@@ -733,10 +733,10 @@ final case class RebuildOption(discriminator: Node, data: Node) extends BinaryNo
}
/** A parameter from a QueryTemplate which gets turned into a bind variable.
  *
  * The lines below come from a diff view: each changed line appears twice, first
  * in its old form and then in its new form. The new form adds an `id` field that
  * defaults to a fresh `AnonSymbol` and is included in the dump output next to
  * the extractor. `rebuild` is `copy()`, so a rebuilt node keeps its extractor,
  * build type and id. */
final case class QueryParameter(extractor: (Any => Any), buildType: Type) extends NullaryNode with SimplyTypedNode {
final case class QueryParameter(extractor: (Any => Any), buildType: Type, id: TermSymbol = new AnonSymbol) extends NullaryNode with SimplyTypedNode {
type Self = QueryParameter
def rebuild = copy()
override def getDumpInfo = super.getDumpInfo.copy(mainInfo = extractor + "@" + System.identityHashCode(extractor))
override def getDumpInfo = super.getDumpInfo.copy(mainInfo = s"$id $extractor")
}
object QueryParameter {
@@ -747,17 +747,17 @@ object QueryParameter {
* `QueryParameter`. */
def constOp[T](name: String)(op: (T, T) => T)(l: Node, r: Node)(implicit tpe: ScalaBaseType[T]): Node = (l, r) match {
case (LiteralNode(lv) :@ (lt: TypedType[_]), LiteralNode(rv) :@ (rt: TypedType[_])) if lt.scalaType == tpe && rt.scalaType == tpe => LiteralNode[T](op(lv.asInstanceOf[T], rv.asInstanceOf[T])).infer()
case (LiteralNode(lv) :@ (lt: TypedType[_]), QueryParameter(re, rt: TypedType[_])) if lt.scalaType == tpe && rt.scalaType == tpe =>
case (LiteralNode(lv) :@ (lt: TypedType[_]), QueryParameter(re, rt: TypedType[_], _)) if lt.scalaType == tpe && rt.scalaType == tpe =>
QueryParameter(new (Any => T) {
def apply(param: Any) = op(lv.asInstanceOf[T], re(param).asInstanceOf[T])
override def toString = s"($lv $name $re)"
}, tpe)
case (QueryParameter(le, lt: TypedType[_]), LiteralNode(rv) :@ (rt: TypedType[_])) if lt.scalaType == tpe && rt.scalaType == tpe =>
case (QueryParameter(le, lt: TypedType[_], _), LiteralNode(rv) :@ (rt: TypedType[_])) if lt.scalaType == tpe && rt.scalaType == tpe =>
QueryParameter(new (Any => T) {
def apply(param: Any) = op(le(param).asInstanceOf[T], rv.asInstanceOf[T])
override def toString = s"($le $name $rv)"
}, tpe)
case (QueryParameter(le, lt: TypedType[_]), QueryParameter(re, rt: TypedType[_])) if lt.scalaType == tpe && rt.scalaType == tpe =>
case (QueryParameter(le, lt: TypedType[_], _), QueryParameter(re, rt: TypedType[_], _)) if lt.scalaType == tpe && rt.scalaType == tpe =>
QueryParameter(new (Any => T) {
def apply(param: Any) = op(le(param).asInstanceOf[T], re(param).asInstanceOf[T])
override def toString = s"($le $name $re)"
@@ -101,20 +101,41 @@ class MergeToComprehensions extends Phase {
case Bind(s1, GroupBy(s2, f1, b1, ts1), Pure(str1, ts2)) =>
val (c1, replacements1) = mergeFilterWhere(f1, true)
logger.debug("Merging GroupBy into Comprehension:", Ellipsis(n, List(0, 0)))
val b2 = applyReplacements(b1, replacements1, c1)
val (c1a, replacements1a, b2a) = {
val b2 = applyReplacements(b1, replacements1, c1)
// Check whether groupBy keys containing bind variables are returned for further use
// and push the current Comprehension into a subquery if this is the case.
val leakedPaths =
str1.collect({ case FwdPath(s :: ElementSymbol(1) :: rest) if s == s1 => rest }, stopOnMatch = true)
val isParam = leakedPaths.nonEmpty && ({
logger.debug("Leaked paths to GroupBy keys: " + leakedPaths.map(l => ("_" :: l).mkString(".")).mkString(", "))
val targets = leakedPaths.map(_.foldLeft(b2)(_ select _))
targets.indexWhere(_.findNode {
case _: QueryParameter => true
case n: LiteralNode => n.volatileHint
case _ => false
}.isDefined) >= 0
})
if(isParam) {
logger.debug("Pushing GroupBy source into subquery to avoid repeated parameter")
val (c1a, replacements1a) = toSubquery(c1, replacements1)
val b2a = applyReplacements(b1, replacements1a, c1a)
(c1a, replacements1a, b2a)
} else (c1, replacements1, b2)
}
val str2 = str1.replace {
case Aggregate(_, FwdPath(s :: ElementSymbol(2) :: Nil), v) if s == s1 =>
applyReplacements(v, replacements1, c1).replace {
applyReplacements(v, replacements1a, c1a).replace {
case Apply(f: AggregateFunctionSymbol, ConstArray(ch)) :@ tpe =>
Apply(f, ConstArray(ch match {
case StructNode(ConstArray(ch, _*)) => ch._2
case n => n
}))(tpe)
}
case FwdPath(s :: ElementSymbol(1) :: rest) if s == s1 =>
rest.foldLeft(b2) { case (n, s) => n.select(s) }.infer()
rest.foldLeft(b2a) { case (n, s) => n.select(s) }.infer()
}
val c2 = c1.copy(groupBy = Some(ProductNode(ConstArray(b2)).flatten), select = Pure(str2, ts2)).infer()
val c2 = c1a.copy(groupBy = Some(ProductNode(ConstArray(b2a)).flatten), select = Pure(str2, ts2)).infer()
logger.debug("Merged GroupBy into Comprehension:", c2)
val StructNode(defs2) = str2
val replacements = defs2.iterator.map { case (f, _) => (ts2, f) -> f }.toMap
@@ -397,7 +397,7 @@ trait JdbcStatementBuilderComponent { driver: JdbcDriver =>
}
b" end)"
case OptionApply(ch) => expr(ch, skipParens)
case QueryParameter(extractor, JdbcType(ti, option)) =>
case QueryParameter(extractor, JdbcType(ti, option), _) =>
b +?= { (p, idx, param) =>
if(option) ti.setOption(extractor(param).asInstanceOf[Option[Any]], p, idx)
else ti.setValue(extractor(param), p, idx)
@@ -229,7 +229,7 @@ class QueryInterpreter(db: HeapBackend#Database, params: Any) extends Logging {
if(opt && !c.elseClause.nodeType.asInstanceOf[ScalaType[_]].nullable) Option(res)
else res
}
case QueryParameter(extractor, _) =>
case QueryParameter(extractor, _, _) =>
extractor(params)
case Library.SilentCast(ch) =>
val chV = run(ch)

0 comments on commit 8f8ae8b

Please sign in to comment.