Permalink
Browse files

Select columns in the order in which the ResultConverters need them

Slick 3.1.0 abandoned the fixed result columns at the top level of a
query in order to remove unnecessary and duplicate columns. As a
side-effect the columns were randomly reordered. This turns out to be
a problem for `forceInsertQuery`, which expects the columns to be in
the specified order to match the insert operation’s column list. In
order to fix this problem `removeFieldNames` now uses a column order
which is derived from the ResultConverters.

Fixes #1338. Test in InsertTest.testForced.
  • Loading branch information...
szeiger committed Dec 15, 2015
1 parent 1673ff7 commit 55cbecca2fd3f83f8325bbb3f063b066cdbcfe16
@@ -85,29 +85,35 @@ class InsertTest extends AsyncTest[JdbcTestDB] {
}
def testForced = {
class T(tname: String)(tag: Tag) extends Table[(Int, String)](tag, tname) {
class T(tname: String)(tag: Tag) extends Table[(Int, String, Int, Boolean, String, String, Int)](tag, tname) {
def id = column[Int]("id", O.AutoInc, O.PrimaryKey)
def name = column[String]("name")
def * = (id, name)
def ins = (id, name)
def i1 = column[Int]("i1")
def b = column[Boolean]("b")
def s1 = column[String]("s1", O.Length(10,varying=true))
def s2 = column[String]("s2", O.Length(10,varying=true))
def i2 = column[Int]("i2")
def * = (id, name, i1, b, s1, s2, i2)
def ins = (id, name, i1, b, s1, s2, i2)
}
val ts = TableQuery(new T("t_forced")(_))
val src = TableQuery(new T("src_forced")(_))
seq(
(ts.schema ++ src.schema).create,
ts += (101, "A"),
ts.map(_.ins) ++= Seq((102, "B"), (103, "C")),
ts += (101, "A", 1, false, "S1", "S2", 0),
ts.map(_.ins) ++= Seq((102, "B", 1, false, "S1", "S2", 0), (103, "C", 1, false, "S1", "S2", 0)),
ts.filter(_.id > 100).length.result.map(_ shouldBe 0),
ifCap(jcap.forceInsert)(seq(
ts.forceInsert(104, "A"),
ts.map(_.ins).forceInsertAll(Seq((105, "B"), (106, "C"))),
ts.forceInsert(104, "A", 1, false, "S1", "S2", 0),
ts.map(_.ins).forceInsertAll(Seq((105, "B", 1, false, "S1", "S2", 0), (106, "C", 1, false, "S1", "S2", 0))),
ts.filter(_.id > 100).length.result.map(_ shouldBe 3),
ts.map(_.ins).forceInsertAll(Seq((111, "D"))),
ts.map(_.ins).forceInsertAll(Seq((111, "D", 1, false, "S1", "S2", 0))),
ts.filter(_.id > 100).length.result.map(_ shouldBe 4),
src.forceInsert(90, "X"),
ts.forceInsertQuery(src).map(_ shouldBe 1),
ts.filter(_.id.between(90, 99)).map(_.name).result.map(_ shouldBe Seq("X"))
src.forceInsert(90, "X", 1, false, "S1", "S2", 0),
mark("forceInsertQuery", ts.forceInsertQuery(src)).map(_ shouldBe 1),
ts.filter(_.id.between(90, 99)).result.headOption.map(_ shouldBe Some((90, "X", 1, false, "S1", "S2", 0)))
))
)
}
@@ -12,7 +12,10 @@ class RemoveFieldNames(val alwaysKeepSubqueryNames: Boolean = false) extends Pha
def apply(state: CompilerState) = state.map { n => ClientSideOp.mapResultSetMapping(n, true) { rsm =>
val CollectionType(_, NominalType(top, StructType(fdefs))) = rsm.from.nodeType
val indexes = fdefs.iterator.zipWithIndex.map { case ((s, _), i) => (s, ElementSymbol(i+1)) }.toMap
val requiredSyms = rsm.map.collect[TermSymbol]({
case Select(Ref(s), f) if s == rsm.generator => f
}, stopOnMatch = true).toSeq.distinct.zipWithIndex.toMap
logger.debug("Required symbols: " + requiredSyms.mkString(", "))
val rsm2 = rsm.nodeMapServerSide(false, { n =>
val refTSyms = n.collect[TypeSymbol] {
case Select(_ :@ NominalType(s, _), _) => s
@@ -26,15 +29,23 @@ class RemoveFieldNames(val alwaysKeepSubqueryNames: Boolean = false) extends Pha
// Always convert an empty StructNode because there is nothing to reference
(Pure(ProductNode(ConstArray.empty), pts), pts)
case Pure(StructNode(ch), pts) if unrefTSyms contains pts =>
(Pure(if(ch.length == 1 && pts != top) ch(0)._2 else ProductNode(ch.map(_._2)), pts), pts)
val sel =
if(ch.length == 1 && pts != top) ch(0)._2
else if(pts != top) ProductNode(ch.map(_._2))
else ProductNode(ConstArray.from(ch.map { case (s, n) => (requiredSyms.getOrElse(s, Int.MaxValue), n) }.toSeq.sortBy(_._1)).map(_._2))
(Pure(sel, pts), pts)
case Pure(StructNode(ch), pts) if pts == top =>
val sel =
StructNode(ConstArray.from(ch.map { case (s, n) => (requiredSyms.getOrElse(s, Int.MaxValue), (s, n)) }.toSeq.sortBy(_._1)).map(_._2))
(Pure(sel, pts), pts)
}.infer()
})
logger.debug("Transformed RSM: ", rsm2)
val CollectionType(_, fType) = rsm2.from.nodeType
val baseRef = Ref(rsm.generator) :@ fType
rsm2.copy(map = rsm2.map.replace({
case Select(Ref(s), f) if s == rsm.generator =>
Select(baseRef, indexes(f)).infer()
Select(baseRef, ElementSymbol(requiredSyms(f) + 1)).infer()
}, keepType = true)) :@ rsm.nodeType
}}
}

0 comments on commit 55cbecc

Please sign in to comment.