Permalink
Browse files

Don’t skip `expandSums` if `expandTables` creates an `OptionApply`

- The rather expensive `exapndSums` phase is skipped if
  `assignUniqueSymbols` does not detect any nested Options in the AST
  that would need to be rewritten by `expandSums`. This fails, however,
  when `expandTables` injects a new `OptionApply` operation into an AST
  that did not already contain nested Options. The fix is to override
  the previously set flag in this case so that `expandSums` will run.

- We also add an optimization to `expandSums` that removes identity
  OptionFold/OptionApply combinations like the one produced by returning
  `Rep.None[Int]` in the test case. Inserting a `SilentCast` is the
  easiest solution to keeping the types correct before `expandRecords`.
  This requires a change in `MemoryCodeGen` to deal with these extra
  casts. An alternative approach would be to invalidate and rebuild the
  types recursively.

Test in JoinTest.testJoin. Fixes #1345.
  • Loading branch information...
szeiger committed Dec 11, 2015
1 parent 9db8774 commit b0e27e6a581ac5368835460a6e8fced5d768b1e8
@@ -59,6 +59,10 @@ class JoinTest extends AsyncTest[RelationalTestDB] {
if a1.id === a4.id
} yield a1.id).to[Set]
_ <- mark("q4", q4.result).map(_ shouldBe Set(1, 2, 3, 4))
q5 = (for {
c <- categories
} yield (c, Rep.None[Int])).sortBy(_._1.id)
_ <- mark("q5", q5.result.map(_.map(_._1._1))).map(_ shouldBe List(1,2,3,4))
} yield ()
}
@@ -103,14 +103,14 @@ class NestingTest extends AsyncTest[RelationalTestDB] {
val q4t: Query[Rep[Option[Option[Int]]], _, Seq] = q4
val q5t: Query[(Rep[Option[Int]], Rep[Option[String]]), _, Seq] = q5
val t1 = seq(
q1.result.named("q1").map(_ shouldBe r.map(t => Some(t))),
q1a2.result.named("q1a2").map(_ shouldBe r.map(t => Some(Some(t)))),
q2.result.named("q2").map(_ shouldBe r.map(t => Some(t._1))),
q2a2.result.named("q2a2").map(_ shouldBe r.map(t => Some(Some(t._1)))),
q3.result.named("q3").map(_ shouldBe r.map(t => t._3)),
q4.result.named("q4").map(_ shouldBe r.map(t => Some(t._3))),
q5.result.named("q5").map(_ shouldBe r.map(t => (t._3, Some(t._2))))
lazy val t1 = seq(
mark("q1", q1.result).map(_ shouldBe r.map(t => Some(t))),
mark("q1a2", q1a2.result).map(_ shouldBe r.map(t => Some(Some(t)))),
mark("q2", q2.result).map(_ shouldBe r.map(t => Some(t._1))),
mark("q2a2", q2a2.result).map(_ shouldBe r.map(t => Some(Some(t._1)))),
mark("q3", q3.result).map(_ shouldBe r.map(t => t._3)),
mark("q4", q4.result).map(_ shouldBe r.map(t => Some(t._3))),
mark("q5", q5.result).map(_ shouldBe r.map(t => (t._3, Some(t._2))))
)
// Get plain values out
@@ -123,7 +123,7 @@ class NestingTest extends AsyncTest[RelationalTestDB] {
val q3bt: Query[Rep[Int], _, Seq] = q3b
val q4bt: Query[Rep[Option[Int]], _, Seq] = q4b
val t2 = seq(
lazy val t2 = seq(
mark("q1b", q1b.result).map(_ shouldBe r.map(t => Some(t)).map(_.getOrElse((0, "", None: Option[String])))),
mark("q2b", q2b.result).map(_ shouldBe r.map(t => Some(t._1)).map(_.get)),
mark("q3b", q3b.result).map(_ shouldBe r.map(t => t._3).filter(_.isDefined).map(_.get)),
@@ -141,9 +141,9 @@ class NestingTest extends AsyncTest[RelationalTestDB] {
val q2c = q2.map(io => io + 42)
val q3c = q3.map(so => so + 10)
val t3 = seq(
q2c.result.named("q2c").map(_ shouldBe r.map(t => Some(t._1)).map(_.map(_ + 42))),
q3c.result.named("q3c").map(_ shouldBe r.map(t => t._3).map(_.map(_ + 10)))
lazy val t3 = seq(
mark("q2c", q2c.result).map(_ shouldBe r.map(t => Some(t._1)).map(_.map(_ + 42))),
mark("q3c", q3c.result).map(_ shouldBe r.map(t => t._3).map(_.map(_ + 10)))
)
// Use Option.map
@@ -162,7 +162,7 @@ class NestingTest extends AsyncTest[RelationalTestDB] {
val q3dt: Query[Rep[Option[(Rep[Int], Rep[Int], ConstColumn[Int])]], _, Seq] = q3d
val q4dt: Query[Rep[Option[Int]], _, Seq] = q4d
val t4 = seq(
lazy val t4 = seq(
q1d.result.named("q1d").map(_ shouldBe r.map(t => Some(t)).map(_.map(_._1))),
q1d2.result.named("q1d2").map(_ shouldBe r.map(t => Some(t)).map(_.map(x => (x._1, x._2, x._3)))),
q2d.result.named("q2d").map(_ shouldBe r.map(t => Some(t._1)).map(_.map(_ + 1))),
@@ -179,11 +179,11 @@ class NestingTest extends AsyncTest[RelationalTestDB] {
val q1e2t: Query[Rep[Option[Int]], _, Seq] = q1e2
val q2et: Query[Rep[Option[Int]], _, Seq] = q2e
val t5 = seq(
q1e1.result.named("q1e1").map(_ shouldBe r.map(t => Some(t)).map { to => to.flatMap { t => Some(t._2) }}),
q1e2.result.named("q1e2").map(_ shouldBe r.map(t => Some(t)).map { to => to.flatMap { t => t._3 }}),
q1e3.result.named("q1e3").map(_ shouldBe r.map(t => Some(t)).map(to => Some(to)).map(_.flatMap(identity))),
q2e.result.named("q2e").map(_ shouldBe r.map(t => Some(t._1)).map { io => io.flatMap { i => Some(i) }})
lazy val t5 = seq(
mark("q1e1", q1e1.result).map(_ shouldBe r.map(t => Some(t)).map { to => to.flatMap { t => Some(t._2) }}),
mark("q1e2", q1e2.result).map(_ shouldBe r.map(t => Some(t)).map { to => to.flatMap { t => t._3 }}),
mark("q1e3", q1e3.result).map(_ shouldBe r.map(t => Some(t)).map(to => Some(to)).map(_.flatMap(identity))),
mark("q2e", q2e.result).map(_ shouldBe r.map(t => Some(t._1)).map { io => io.flatMap { i => Some(i) }})
)
// Use Option.flatten
@@ -200,7 +200,7 @@ class NestingTest extends AsyncTest[RelationalTestDB] {
val q2f2t: Query[Rep[Option[Int]], _, Seq] = q2f2
val q2f3t: Query[Rep[Option[Int]], _, Seq] = q2f3
val t6 = seq(
lazy val t6 = seq(
q1f1.result.named("q1f1").map(_ shouldBe Vector(Some(Some((1,"1",Some(1)))), Some(Some((2,"2",Some(2)))), Some(Some((3,"3",None))))),
q1f2.result.named("q1f2").map(_ shouldBe r.map(t => Some(t)).map { to => Some(to).flatten }),
q1f3.result.named("q1f3").map(_ shouldBe r.map(t => Some(t)).map { to => Some(to) }.map(_.flatten)),
@@ -36,6 +36,10 @@ class ExpandSums extends Phase {
multi = true
IfThenElse(ConstArray(pred, then1, buildMultiColumnNone(tpe))) :@ tpe
// Identity OptionFold/OptionApply combination -> remove
case OptionFold(from, LiteralNode(None) :@ OptionType(ScalaBaseType.nullType), oa @ OptionApply(Ref(s)), gen) if s == gen =>
silentCast(oa.nodeType, from)
// Primitive OptionFold representing GetOrElse -> translate to GetOrElse
case OptionFold(from :@ OptionType.Primitive(_), LiteralNode(v), Ref(s), gen) if s == gen =>
GetOrElse(from, () => v).infer()
@@ -13,62 +13,69 @@ import scala.collection.mutable
class ExpandTables extends Phase {
val name = "expandTables"
def apply(state: CompilerState) = state.map { n => ClientSideOp.mapServerSide(n) { tree =>
// Find table fields
val structs = tree.collect[(TypeSymbol, (FieldSymbol, Type))] {
case s @ Select(_ :@ (n: NominalType), sym: FieldSymbol) => n.sourceNominalType.sym -> (sym -> s.nodeType)
}.toSeq.groupBy(_._1).map { case (ts, v) => (ts, NominalType(ts, StructType(ConstArray.from(v.map(_._2).toMap)))) }
logger.debug("Found Selects for NominalTypes: "+structs.keySet.mkString(", "))
def apply(state: CompilerState) = {
var createdOption = false
val tables = new mutable.HashMap[TableIdentitySymbol, (TermSymbol, Node)]
var expandDistinct = false
def tr(tree: Node): Node = tree.replace {
case t: TableExpansion =>
val ts = t.table.asInstanceOf[TableNode].identity
tables += ((ts, (t.generator, t.columns)))
t.table :@ CollectionType(t.nodeType.asCollectionType.cons, structs(ts))
case r: Ref => r.untyped
case d: Distinct =>
if(d.nodeType.existsType { case NominalType(_: TableIdentitySymbol, _) => true; case _ => false })
expandDistinct = true
d.mapChildren(tr)
/** Create an expression that copies a structured value, expanding tables in it. */
def createResult(expansions: collection.Map[TableIdentitySymbol, (TermSymbol, Node)], path: Node, tpe: Type): Node = tpe match {
case p: ProductType =>
ProductNode(p.elements.zipWithIndex.map { case (t, i) => createResult(expansions, Select(path, ElementSymbol(i+1)), t) })
case NominalType(tsym: TableIdentitySymbol, _) if expansions contains tsym =>
val (sym, exp) = expansions(tsym)
exp.replace { case Ref(s) if s == sym => path }
case tpe: NominalType => createResult(expansions, path, tpe.structuralView)
case m: MappedScalaType =>
TypeMapping(createResult(expansions, path, m.baseType), m.mapper, m.classTag)
case OptionType(el) =>
val gen = new AnonSymbol
createdOption = true
OptionFold(path, LiteralNode.nullOption, OptionApply(createResult(expansions, Ref(gen), el)), gen)
case _ => path
}
val tree2 = tr(tree).infer()
logger.debug("With correct table types:", tree2)
logger.debug("Table expansions: " + tables.mkString(", "))
// Perform star expansion in Distinct
val tree3 = if(!expandDistinct) tree2 else {
logger.debug("Expanding tables in Distinct")
tree2.replace({
case Distinct(s, f, o) => Distinct(s, f, createResult(tables, Ref(s), o.nodeType))
}, bottomUp = true).infer()
}
val s2 = state.map { n => ClientSideOp.mapServerSide(n) { tree =>
// Find table fields
val structs = tree.collect[(TypeSymbol, (FieldSymbol, Type))] {
case s @ Select(_ :@ (n: NominalType), sym: FieldSymbol) => n.sourceNominalType.sym -> (sym -> s.nodeType)
}.toSeq.groupBy(_._1).map { case (ts, v) => (ts, NominalType(ts, StructType(ConstArray.from(v.map(_._2).toMap)))) }
logger.debug("Found Selects for NominalTypes: "+structs.keySet.mkString(", "))
// Perform star expansion in query result
if(!tree.nodeType.existsType { case NominalType(_: TableIdentitySymbol, _) => true; case _ => false }) tree3 else {
logger.debug("Expanding tables in result type")
// Create a mapping that expands the tables
val sym = new AnonSymbol
val mapping = createResult(tables, Ref(sym), tree3.nodeType.asCollectionType.elementType)
.infer(Type.Scope(sym -> tree3.nodeType.asCollectionType.elementType))
Bind(sym, tree3, Pure(mapping)).infer()
}
}}.withWellTyped(true)
val tables = new mutable.HashMap[TableIdentitySymbol, (TermSymbol, Node)]
var expandDistinct = false
def tr(tree: Node): Node = tree.replace {
case t: TableExpansion =>
val ts = t.table.asInstanceOf[TableNode].identity
tables += ((ts, (t.generator, t.columns)))
t.table :@ CollectionType(t.nodeType.asCollectionType.cons, structs(ts))
case r: Ref => r.untyped
case d: Distinct =>
if(d.nodeType.existsType { case NominalType(_: TableIdentitySymbol, _) => true; case _ => false })
expandDistinct = true
d.mapChildren(tr)
}
val tree2 = tr(tree).infer()
logger.debug("With correct table types:", tree2)
logger.debug("Table expansions: " + tables.mkString(", "))
// Perform star expansion in Distinct
val tree3 = if(!expandDistinct) tree2 else {
logger.debug("Expanding tables in Distinct")
tree2.replace({
case Distinct(s, f, o) => Distinct(s, f, createResult(tables, Ref(s), o.nodeType))
}, bottomUp = true).infer()
}
/** Create an expression that copies a structured value, expanding tables in it. */
def createResult(expansions: collection.Map[TableIdentitySymbol, (TermSymbol, Node)], path: Node, tpe: Type): Node = tpe match {
case p: ProductType =>
ProductNode(p.elements.zipWithIndex.map { case (t, i) => createResult(expansions, Select(path, ElementSymbol(i+1)), t) })
case NominalType(tsym: TableIdentitySymbol, _) if expansions contains tsym =>
val (sym, exp) = expansions(tsym)
exp.replace { case Ref(s) if s == sym => path }
case tpe: NominalType => createResult(expansions, path, tpe.structuralView)
case m: MappedScalaType =>
TypeMapping(createResult(expansions, path, m.baseType), m.mapper, m.classTag)
case OptionType(el) =>
val gen = new AnonSymbol
OptionFold(path, LiteralNode.nullOption, OptionApply(createResult(expansions, Ref(gen), el)), gen)
case _ => path
// Perform star expansion in query result
if(!tree.nodeType.existsType { case NominalType(_: TableIdentitySymbol, _) => true; case _ => false }) tree3 else {
logger.debug("Expanding tables in result type")
// Create a mapping that expands the tables
val sym = new AnonSymbol
val mapping = createResult(tables, Ref(sym), tree3.nodeType.asCollectionType.elementType)
.infer(Type.Scope(sym -> tree3.nodeType.asCollectionType.elementType))
Bind(sym, tree3, Pure(mapping)).infer()
}
}}.withWellTyped(true)
if(createdOption) s2 + (Phase.assignUniqueSymbols -> state.get(Phase.assignUniqueSymbols).get.copy(nonPrimitiveOption = true))
else s2
}
}
@@ -63,6 +63,7 @@ trait MemoryQueryingDriver extends BasicDriver with MemoryQueryingProfile { driv
case Bind(gen, g: GroupBy, p @ Pure((_: ProductNode | _: StructNode), _)) =>
val p2 = transformCountAll(gen, p)
if(p2 eq p) n else Bind(gen, g, p2).infer(typeChildren = true)
case Library.SilentCast(n :@ tpe1) :@ tpe2 if tpe1 == tpe2 => n
case n => n
}

0 comments on commit b0e27e6

Please sign in to comment.