Merge pull request #782 from slick/tmp/insert-or-update
Insert or Update
szeiger committed May 13, 2014
2 parents f59b15c + 4fa915c commit ba2c47a
Showing 31 changed files with 892 additions and 460 deletions.
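
At a glance, the commit adds an insertOrUpdate operation to the Lifted Embedding plus the compiler and driver support behind it. A minimal usage sketch, based on the new tests below (T stands for any table with an Int primary key "id"; the returning variant requires driver support for returning insert keys):

val ts = TableQuery[T]
ts.insertOrUpdate((3, "c"))  // no row with id 3 yet: inserts one row, returns 1
ts.insertOrUpdate((1, "d"))  // row with id 1 exists: updates it, returns 1
// With a `returning` projection on an AutoInc key, an insert yields the
// generated key wrapped in Some and an update yields None:
(ts returning ts.map(_.id)).insertOrUpdate((0, "e"))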
@@ -20,8 +20,9 @@ class InsertTest extends TestkitTest[JdbcTestDB] {
val src1 = TableQuery(new TestTable(_, "src1_q"))
val dst1 = TableQuery(new TestTable(_, "dst1_q"))
val dst2 = TableQuery(new TestTable(_, "dst2_q"))
val dst3 = TableQuery(new TestTable(_, "dst3_q"))

(src1.ddl ++ dst1.ddl ++ dst2.ddl).create
(src1.ddl ++ dst1.ddl ++ dst2.ddl ++ dst3.ddl).create

src1.insert(1, "A")
src1.map(_.ins).insertAll((2, "B"), (3, "C"))
@@ -39,6 +40,11 @@ class InsertTest extends TestkitTest[JdbcTestDB] {
dst2.insertExpr(q3)
assertEquals(Set((1,"A"), (2,"B"), (42,"X")), dst2.list.toSet)

val q4comp = Compiled { dst2.filter(_.id < 10) }
val dst3comp = Compiled { dst3 }
dst3comp.insert(q4comp)
assertEquals(Set((1,"A"), (2,"B")), dst3comp.run.toSet)

/*val q4 = (43, "Y".bind)
println("Insert 4: "+Dst2.shaped.insertStatementFor(q4))
Dst2.shaped.insertExpr(q4)
@@ -108,4 +114,45 @@ class InsertTest extends TestkitTest[JdbcTestDB] {
assertEquals(3, ts.filter(_.id > 100).length.run)
}
}

def testInsertOrUpdatePlain {
class T(tag: Tag) extends Table[(Int, String)](tag, "t_merge") {
def id = column[Int]("id", O.PrimaryKey)
def name = column[String]("name")
def * = (id, name)
def ins = (id, name)
}
val ts = TableQuery[T]

ts.ddl.create

ts ++= Seq((1, "a"), (2, "b"))
assertEquals(1, ts.insertOrUpdate((3, "c")))
assertEquals(1, ts.insertOrUpdate((1, "d")))
assertEquals(Seq((1, "d"), (2, "b"), (3, "c")), ts.sortBy(_.id).run)
}

def testInsertOrUpdateAutoInc {
class T(tag: Tag) extends Table[(Int, String)](tag, "t_merge2") {
def id = column[Int]("id", O.AutoInc, O.PrimaryKey)
def name = column[String]("name")
def * = (id, name)
def ins = (id, name)
}
val ts = TableQuery[T]

ts.ddl.create

ts ++= Seq((1, "a"), (2, "b"))
assertEquals(1, ts.insertOrUpdate((0, "c")))
assertEquals(1, ts.insertOrUpdate((1, "d")))
assertEquals(Seq((1, "d"), (2, "b"), (3, "c")), ts.sortBy(_.id).run)

ifCap(jcap.returnInsertKey) {
val q = ts returning ts.map(_.id)
assertEquals(Some(4), q.insertOrUpdate((0, "e")))
assertEquals(None, q.insertOrUpdate((1, "f")))
assertEquals(Seq((1, "f"), (2, "b"), (3, "c"), (4, "e")), ts.sortBy(_.id).run)
}
}
}
@@ -72,7 +72,8 @@ class TemplateTest extends TestkitTest[RelationalTestDB] {
}
def ts = TableQuery[T]
ts.ddl.create
ts ++= Seq((1, "a"), (2, "b"), (3, "c"))
Compiled(ts.map(identity)) += (1, "a")
Compiled(ts) ++= Seq((2, "b"), (3, "c"))

val byIdAndS = { (id: Column[Int], s: ConstColumn[String]) => ts.filter(t => t.id === id && t.s === s) }
val byIdAndSC = Compiled(byIdAndS)
@@ -105,5 +106,8 @@ class TemplateTest extends TestkitTest[RelationalTestDB] {
val r5 = countBelowC(3).run
val r5t: Int = r5
assertEquals(2, r5)

val joinC = Compiled { id: Column[Int] => ts.filter(_.id === id).innerJoin(ts.filter(_.id === id)) }
assertEquals(Seq(((1, "a"), (1, "a"))), joinC(1).run)
}
}
@@ -84,7 +84,7 @@ class Db1 extends {
import scala.slick.ast._
node match{
case TableExpansion(generator, tableNode, columns) => tableName(tableNode)
case TableNode(schemaName, tableName, identity, driverTable) => tableName
case TableNode(schemaName, tableName, identity, driverTable, _) => tableName
}
}

25 changes: 13 additions & 12 deletions src/main/scala/scala/slick/ast/Insert.scala
@@ -1,25 +1,26 @@
package scala.slick.ast

/** Represents an Insert operation. */
final case class Insert(generator: Symbol, table: Node, map: Node, linear: Node) extends Node with DefNode {
final case class Insert(tableSym: Symbol, table: Node, linear: Node) extends BinaryNode with DefNode {
type Self = Insert
def left = table
def right = map
override def nodeChildNames = Vector("table", "map", "linear")
def nodeChildren = Vector(table, map, linear)
def nodeGenerators = Vector((generator, table))
def nodeRebuild(ch: IndexedSeq[Node]) = copy(table = ch(0), map = ch(1), linear = ch(2))
def nodeRebuildWithGenerators(gen: IndexedSeq[Symbol]) = copy(generator = gen(0))
def right = linear
override def nodeChildNames = Vector("table "+tableSym, "linear")
def nodeGenerators = Vector((tableSym, table))
def nodeRebuild(l: Node, r: Node) = copy(table = l, linear = r)
def nodeRebuildWithGenerators(gen: IndexedSeq[Symbol]) = copy(tableSym = gen(0))
def nodeWithComputedType2(scope: SymbolScope, typeChildren: Boolean, retype: Boolean): Self = {
val table2 = table.nodeWithComputedType(scope, typeChildren, retype)
val map2 = map.nodeWithComputedType(scope + (generator -> table2.nodeType), typeChildren, retype)
nodeRebuildOrThis(Vector(table2, map2, linear)).nodeTypedOrCopy(if(!nodeHasType || retype) map2.nodeType else nodeType)
val lin2 = linear.nodeWithComputedType(scope + (tableSym -> table2.nodeType), typeChildren, retype)
nodeRebuildOrThis(Vector(table2, lin2)).nodeTypedOrCopy(if(!nodeHasType || retype) lin2.nodeType else nodeType)
}
override def toString = "Insert"
}

/** A column in an Insert operation. */
final case class InsertColumn(child: Node, fs: FieldSymbol) extends UnaryNode with SimplyTypedNode {
final case class InsertColumn(children: IndexedSeq[Node], fs: FieldSymbol, tpe: Type) extends Node with TypedNode {
def nodeChildren = children
type Self = InsertColumn
def nodeRebuild(ch: Node) = copy(child = ch)
def buildType = child.nodeType
protected[this] def nodeRebuild(ch: IndexedSeq[Node]) = copy(children = ch)
override def toString = s"InsertColumn $fs"
}
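
A rough sketch of the trees these reshaped nodes form for a two-column table, following the InsertCompiler changes further down (names abbreviated; not taken verbatim from compiler output):

// Server-side part: the table plus a linear product of the columns to write.
//   Insert(tableSym, <table node>,
//     ProductNode(Vector(
//       Select(Ref(tableSym), id),
//       Select(Ref(tableSym), name))))
// Client-side mapping: one InsertColumn per projected column. An included
// column points at its slot in the linearized input, an excluded one is empty.
//   InsertColumn(Vector(Select(Ref(linearSym), ElementSymbol(1))), id, <type>)
//   InsertColumn(Vector.empty, autoIncId, <type>)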
2 changes: 1 addition & 1 deletion src/main/scala/scala/slick/ast/Node.scala
@@ -517,7 +517,7 @@ object FwdPath {
}

/** A Node representing a database table. */
final case class TableNode(schemaName: Option[String], tableName: String, identity: TableIdentitySymbol, driverTable: Any) extends NullaryNode with TypedNode {
final case class TableNode(schemaName: Option[String], tableName: String, identity: TableIdentitySymbol, driverTable: Any, baseIdentity: TableIdentitySymbol) extends NullaryNode with TypedNode {
type Self = TableNode
def tpe = CollectionType(TypedCollectionTypeConstructor.seq, NominalType(identity)(UnassignedStructuralType(identity)))
def nodeRebuild = copy()
37 changes: 23 additions & 14 deletions src/main/scala/scala/slick/compiler/CodeGen.scala
@@ -4,27 +4,36 @@ import slick.ast.{ClientSideOp, CompiledStatement, ResultSetMapping, Node, First
import slick.util.SlickLogger
import org.slf4j.LoggerFactory

/** The code generator phase. The actual implementation is provided
* by a driver. */
/** A standard skeleton for a code generator phase. */
abstract class CodeGen extends Phase {
val name = "codeGen"

override protected[this] lazy val logger = new SlickLogger(LoggerFactory.getLogger(classOf[CodeGen]))
}

object CodeGen {
def apply(f: () => ((Node, CompilerState) => (String, Any))): CodeGen = new CodeGen {
def apply(state: CompilerState): CompilerState = state.map(n => apply(n, state))
def apply(node: Node, state: CompilerState): Node = node match {
case c: ClientSideOp =>
ClientSideOp.mapServerSide(c)(ch => apply(ch, state))
case n =>
val (st, ex) = buildStatement(n, state)
CompiledStatement(st, ex, n.nodeType)
def apply(state: CompilerState): CompilerState = state.map(n => apply(n, state))

def apply(node: Node, state: CompilerState): Node =
ClientSideOp.mapResultSetMapping(node, keepType = true) { rsm =>
var nmap: Option[Node] = None
var compileMap: Option[Node] = Some(rsm.map)

val nfrom = ClientSideOp.mapServerSide(rsm.from, keepType = true) { ss =>
val (nss, nmapOpt) = compileServerSideAndMapping(ss, compileMap, state)
nmapOpt match {
case Some(_) =>
nmap = nmapOpt
compileMap = None
case None =>
}
nss
}
rsm.copy(from = nfrom, map = nmap.get).nodeTyped(rsm.nodeType)
}
def buildStatement(n: Node, state: CompilerState): (String, Any) = f().apply(n, state)
}

def compileServerSideAndMapping(serverSide: Node, mapping: Option[Node], state: CompilerState): (Node, Option[Node])
}

object CodeGen {
def findResult(n: Node): (String, Any) = n match {
case r @ ResultSetMapping(_, from, _) => findResult(from)
case f @ First(from) => findResult(from)
2 changes: 1 addition & 1 deletion src/main/scala/scala/slick/compiler/Columnizer.scala
@@ -18,7 +18,7 @@ class ExpandTables extends Phase {
if(tsyms.isEmpty) tree else {
// Find the corresponding TableExpansions
val tables: Map[TableIdentitySymbol, (Symbol, Node)] = tree.collect {
case TableExpansion(s, TableNode(_, _, ts, _), ex) if tsyms contains ts => ts -> (s, ex)
case TableExpansion(s, TableNode(_, _, ts, _, _), ex) if tsyms contains ts => ts -> (s, ex)
}.toMap
logger.debug("Table expansions: " + tables.mkString(", "))
// Create a mapping that expands the tables
44 changes: 31 additions & 13 deletions src/main/scala/scala/slick/compiler/InsertCompiler.scala
@@ -9,15 +9,15 @@ import Util._

/** A custom compiler for INSERT statements. We could reuse the standard
* phases with a minor modification instead, but this is much faster. */
trait InsertCompiler extends Phase {
class InsertCompiler(val mode: InsertCompiler.Mode) extends Phase {
val name = "insertCompiler"

override protected[this] lazy val logger = new SlickLogger(LoggerFactory.getLogger(classOf[CodeGen]))

def apply(state: CompilerState) = state.map { tree =>
val gen, rgen = new AnonSymbol
val tref = Ref(gen)
val rref = Ref(rgen)
val tableSym, linearSym = new AnonSymbol
val tref = Ref(tableSym)
val rref = Ref(linearSym)

var tableExpansion: TableExpansion = null
var expansionRef: Symbol = null
@@ -36,23 +36,41 @@ trait InsertCompiler extends Phase {
setTable(te)
tr(expansion)
case sel @ Select(Ref(s), fs: FieldSymbol) if s == expansionRef =>
cols += Select(tref, fs).nodeTyped(sel.nodeType)
InsertColumn(Select(rref, ElementSymbol(cols.size)).nodeTyped(sel.nodeType), fs).nodeTyped(sel.nodeType)
val ch =
if(mode(fs)) {
cols += Select(tref, fs).nodeTyped(sel.nodeType)
IndexedSeq(Select(rref, ElementSymbol(cols.size)).nodeTyped(sel.nodeType))
} else IndexedSeq.empty[Node]
InsertColumn(ch, fs, sel.nodeType)
case Ref(s) if s == expansionRef =>
tr(tableExpansion.columns)
case Bind(gen, te @ TableExpansion(_, t: TableNode, _), Pure(sel, _)) =>
setTable(te)
tr(sel.replace({ case Ref(s) if s == gen => Ref(expansionRef) }, keepType = true))
case _ => throw new SlickException("Cannot use node "+n+" for inserting data")
}
val tree2 = tr(tree)
val tree2 = tr(tree).nodeWithComputedType()
if(tableExpansion eq null) throw new SlickException("No table to insert into")
val ins = Insert(gen, tableExpansion.table, tree2, ProductNode(cols)).nodeWithComputedType(SymbolScope.empty, typeChildren = false, retype = true)
logger.debug("Insert node:", ins)

ResultSetMapping(rgen, ins, createMapping(ins)).nodeTyped(
CollectionType(TypedCollectionTypeConstructor.seq, ins.nodeType))
val ins = Insert(tableSym, tableExpansion.table, ProductNode(cols)).nodeWithComputedType(retype = true)
ResultSetMapping(linearSym, ins, tree2).nodeTyped(CollectionType(TypedCollectionTypeConstructor.seq, ins.nodeType))
}
}

object InsertCompiler {
/** Determines which columns to include in the `Insert` and mapping nodes
* created by `InsertCompiler`. */
trait Mode extends (FieldSymbol => Boolean)

def createMapping(ins: Insert): Node
/** Include all columns. For use in forced inserts and merges. */
case object AllColumns extends Mode {
def apply(fs: FieldSymbol) = true
}
/** Include only non-AutoInc columns. For use in standard (soft) inserts. */
case object NonAutoInc extends Mode {
def apply(fs: FieldSymbol) = !fs.options.contains(ColumnOption.AutoInc)
}
/** Include only primary keys. For use in the insertOrUpdate emulation. */
case object PrimaryKeys extends Mode {
def apply(fs: FieldSymbol) = fs.options.contains(ColumnOption.PrimaryKey)
}
}
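
Each mode is just a predicate over a column's FieldSymbol. As a sketch, for a hypothetical AutoInc primary key column (call its symbol idField), the three modes behave as follows; the wiring of the modes into the driver's insert/upsert compiler pipelines is not shown in this hunk:

// given: def id = column[Int]("id", O.AutoInc, O.PrimaryKey)
InsertCompiler.NonAutoInc(idField)   // false: skipped by a normal (soft) insert
InsertCompiler.AllColumns(idField)   // true:  written by forced inserts and merges
InsertCompiler.PrimaryKeys(idField)  // true:  used to locate the existing row in the upsert emulation
// A driver picks a mode when instantiating the phase, e.g.
// new InsertCompiler(InsertCompiler.NonAutoInc)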
3 changes: 2 additions & 1 deletion src/main/scala/scala/slick/direct/SlickBackend.scala
@@ -209,7 +209,8 @@ class SlickBackend( val driver: JdbcDriver, mapper:Mapper ) extends QueryableBac
def typetagToQuery(typetag:TypeTag[_]) : Query = {
def _fields = getConstructorArgs(typetag.tpe)
val tableName = mapper.typeToTable( typetag.tpe )
val table = sq.TableNode(None, tableName, sq.SimpleTableIdentitySymbol(driver, "_", tableName), null)
val tsym = sq.SimpleTableIdentitySymbol(driver, "_", tableName)
val table = sq.TableNode(None, tableName, tsym, null, tsym)
val tableRef = new sq.AnonSymbol
val tableExp = sq.TableExpansion(tableRef, table, sq.TypeMapping(
sq.ProductNode( _fields.map( fieldSym => columnSelect(fieldSym, sq.Ref(tableRef)) )),
6 changes: 5 additions & 1 deletion src/main/scala/scala/slick/driver/AccessDriver.scala
@@ -60,6 +60,9 @@ import java.sql.{Blob, Clob, Date, Time, Timestamp, SQLException, PreparedStatem
* <li>[[scala.slick.profile.RelationalProfile.capabilities.joinFull]]:
Full outer joins are emulated because there is no native support
* for them.</li>
* <li>[[scala.slick.driver.JdbcProfile.capabilities.insertOrUpdate]]:
* InsertOrUpdate operations are emulated on the client side because there
* is no native support for them.</li>
* </ul>
*/
trait AccessDriver extends JdbcDriver { driver =>
@@ -82,7 +85,8 @@ trait AccessDriver extends JdbcDriver { driver =>
- RelationalProfile.capabilities.zip
- JdbcProfile.capabilities.createModel
- RelationalProfile.capabilities.joinFull
)
- JdbcProfile.capabilities.insertOrUpdate
)

def integralTypes = Set(
java.sql.Types.INTEGER,
9 changes: 8 additions & 1 deletion src/main/scala/scala/slick/driver/DerbyDriver.scala
@@ -45,6 +45,11 @@ import scala.slick.jdbc.{Invoker, JdbcType}
* <li>[[scala.slick.profile.RelationalProfile.capabilities.joinFull]]:
Full outer joins are emulated because there is no native support
* for them.</li>
* <li>[[scala.slick.driver.JdbcProfile.capabilities.insertOrUpdate]]:
* InsertOrUpdate operations are emulated on the client side because there
* is no native support for them (but there is work in progress: see
* <a href="https://issues.apache.org/jira/browse/DERBY-3155"
* target="_parent" >DERBY-3155</a>).</li>
* </ul>
*/
trait DerbyDriver extends JdbcDriver { driver =>
@@ -58,6 +63,7 @@ trait DerbyDriver extends JdbcDriver { driver =>
- SqlProfile.capabilities.sequenceCycle
- RelationalProfile.capabilities.zip
- RelationalProfile.capabilities.joinFull
- JdbcProfile.capabilities.insertOrUpdate
)

override def getTables: Invoker[MTable] = MTable.getTables(None, None, None, Some(Seq("TABLE")))
@@ -76,8 +82,9 @@ trait DerbyDriver extends JdbcDriver { driver =>
case _ => super.defaultSqlTypeName(tmd)
}

override protected val scalarFrom = Some("sysibm.sysdummy1")

class QueryBuilder(tree: Node, state: CompilerState) extends super.QueryBuilder(tree, state) {
override protected val scalarFrom = Some("sysibm.sysdummy1")
override protected val supportsTuples = false

override def expr(c: Node, skipParens: Boolean = false): Unit = c match {
19 changes: 19 additions & 0 deletions src/main/scala/scala/slick/driver/H2Driver.scala
@@ -22,6 +22,10 @@ import scala.slick.jdbc.JdbcType
* <li>[[scala.slick.profile.RelationalProfile.capabilities.joinFull]]:
Full outer joins are emulated because there is no native support
* for them.</li>
* <li>[[scala.slick.driver.JdbcProfile.capabilities.insertOrUpdate]]:
* InsertOrUpdate operations are emulated on the client side if the
data to insert contains an `AutoInc` field. Otherwise the operation
is performed natively on the server side.</li>
* </ul>
*/
trait H2Driver extends JdbcDriver { driver =>
@@ -32,9 +36,12 @@ trait H2Driver extends JdbcDriver { driver =>
- SqlProfile.capabilities.sequenceCycle
- JdbcProfile.capabilities.returnInsertOther
- RelationalProfile.capabilities.joinFull
- JdbcProfile.capabilities.insertOrUpdate
)

override def createQueryBuilder(n: Node, state: CompilerState): QueryBuilder = new QueryBuilder(n, state)
override def createUpsertBuilder(node: Insert): InsertBuilder = new UpsertBuilder(node)
override def createCountingInsertInvoker[U](compiled: CompiledInsert) = new CountingInsertInvoker[U](compiled)

override def defaultSqlTypeName(tmd: JdbcType[_]): String = tmd.sqlType match {
case java.sql.Types.VARCHAR => "VARCHAR"
@@ -57,6 +64,18 @@ trait H2Driver extends JdbcDriver { driver =>
case _ =>
}
}

/* Extending super.InsertBuilder here instead of super.UpsertBuilder. MERGE is almost identical to INSERT on H2. */
class UpsertBuilder(ins: Insert) extends super.InsertBuilder(ins) {
override protected def buildInsertStart = allNames.mkString(s"merge into $tableName (", ",", ") ")
}

class CountingInsertInvoker[U](compiled: CompiledInsert) extends super.CountingInsertInvoker[U](compiled) {
// H2 cannot perform server-side insert-or-update with soft insert semantics. We don't have to do
// the same in ReturningInsertInvoker because H2 does not allow returning non-AutoInc keys anyway.
override protected val useServerSideUpsert = compiled.upsert.fields.forall(fs => !fs.options.contains(ColumnOption.AutoInc))
override protected def useTransactionForUpsert = !useServerSideUpsert
}
}

object H2Driver extends H2Driver
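
For reference, a sketch of what the H2 additions do at runtime, using the two tables from the InsertTest changes above (statement shape only; identifier quoting is driver-dependent):

// t_merge: plain Int primary key, so useServerSideUpsert is true and
// insertOrUpdate compiles to a single MERGE statement, roughly
//   merge into "t_merge" ("id","name") values (?,?)
ts.insertOrUpdate((1, "d"))
// t_merge2: AutoInc primary key, so useServerSideUpsert is false and the
// CountingInsertInvoker falls back to the client-side emulation, wrapped in a
// transaction (useTransactionForUpsert = true).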
