Permalink
Browse files

Allow returning of auto-generated keys when inserting data.

  • Loading branch information...
szeiger committed Aug 9, 2012
1 parent a9bd7fb commit 09a65a8e88a0363412e218dc5c06023b69809649
@@ -34,6 +34,8 @@ import java.sql.{Blob, Clob, Date, Time, Timestamp, SQLException}
* <li>Trying to use <code>java.sql.Blob</code> objects causes a NPE in the
* JdbcOdbcDriver. Binary data in the form of <code>Array[Byte]</code> is
* supported.</li>
+ * <li>Returning columns from an INSERT operation is not supported. Trying
+ * to execute such an insert statement throws a SlickException.</li>
* </ul>
*
* @author szeiger
@@ -48,6 +50,7 @@ trait AccessDriver extends ExtendedDriver { driver =>
override val typeMapperDelegates = new TypeMapperDelegates(retryCount)
override def createQueryBuilder(input: QueryBuilderInput): QueryBuilder = new QueryBuilder(input)
+ override def createInsertBuilder(node: Node): InsertBuilder = new InsertBuilder(node)
override def createTableDDLBuilder(table: Table[_]): TableDDLBuilder = new TableDDLBuilder(table)
override def createColumnDDLBuilder(column: FieldSymbol, table: Table[_]): ColumnDDLBuilder = new ColumnDDLBuilder(column)
@@ -135,6 +138,11 @@ trait AccessDriver extends ExtendedDriver { driver =>
override protected def buildFetchOffsetClause(fetch: Option[Long], offset: Option[Long]) = ()
}
+ class InsertBuilder(node: Node) extends super.InsertBuilder(node) {
+ override def buildReturnColumns(node: Node, table: String): IndexedSeq[FieldSymbol] =
+ throw new SlickException("Returning columns from INSERT statements is not supported by Access")
+ }
+
class TableDDLBuilder(table: Table[_]) extends super.TableDDLBuilder(table) {
override protected def addForeignKey(fk: ForeignKey[_ <: TableNode, _], sb: StringBuilder) {
sb append "CONSTRAINT " append quoteIdentifier(fk.name) append " FOREIGN KEY("
@@ -2,10 +2,11 @@ package scala.slick.driver
import java.sql.{Statement, PreparedStatement}
import scala.slick.SlickException
+import scala.slick.ast.Node
import scala.slick.lifted.{Query, Shape, ShapedValue}
import scala.slick.session.{Session, PositionedParameters, PositionedResult}
import scala.slick.util.RecordLinearizer
-import scala.slick.jdbc.{UnitInvokerMixin, MutatingStatementInvoker, MutatingUnitInvoker}
+import scala.slick.jdbc.{UnitInvoker, UnitInvokerMixin, MutatingStatementInvoker, MutatingUnitInvoker, ResultSetInvoker}
trait BasicInvokerComponent { driver: BasicDriver =>
@@ -22,7 +23,7 @@ trait BasicInvokerComponent { driver: BasicDriver =>
def invoker: this.type = this
}
- /** Pseudo-invoker for runing DELETE calls. */
+ /** Pseudo-invoker for running DELETE calls. */
class DeleteInvoker(query: Query[_, _]) {
protected lazy val built = buildDeleteStatement(query)
@@ -36,68 +37,176 @@ trait BasicInvokerComponent { driver: BasicDriver =>
def deleteInvoker: this.type = this
}
- /** Pseudo-invoker for runing INSERT calls. */
- class InsertInvoker[U](unpackable: ShapedValue[_, U]) {
- lazy val insertStatement = buildInsertStatement(unpackable.value)
- def insertStatementFor[TT](query: Query[TT, U]): String = buildInsertStatement(unpackable.value, query).sql
+ /** Pseudo-invoker for running INSERT calls. */
+ abstract class InsertInvoker[U](unpackable: ShapedValue[_, U]) {
+ protected lazy val builder = createInsertBuilder(Node(unpackable.value))
+
+ type RetOne
+ type RetMany
+
+ protected def retOne(st: Statement, value: U, updateCount: Int): RetOne
+ protected def retMany(values: Seq[U], individual: Seq[RetOne]): RetMany
+ protected def retManyBatch(st: Statement, values: Seq[U], updateCounts: Array[Int]): RetMany
+
+ protected lazy val insertResult = builder.buildInsert
+ lazy val insertStatement = insertResult.sql
+ def insertStatementFor[TT](query: Query[TT, U]): String = builder.buildInsert(query).sql
def insertStatementFor[TT](c: TT)(implicit shape: Shape[TT, U, _]): String = insertStatementFor(Query(c)(shape))
def useBatchUpdates(implicit session: Session) = session.capabilities.supportsBatchUpdates
- /**
- * Insert a single row.
- */
- def insert(value: U)(implicit session: Session): Int = session.withPreparedStatement(insertStatement) { st =>
+ protected def prepared[T](sql: String)(f: PreparedStatement => T)(implicit session: Session) =
+ session.withPreparedStatement(sql)(f)
+
+ /** Insert a single row. */
+ def insert(value: U)(implicit session: Session): RetOne = prepared(insertStatement) { st =>
st.clearParameters()
unpackable.linearizer.narrowedLinearizer.asInstanceOf[RecordLinearizer[U]].setParameter(driver, new PositionedParameters(st), Some(value))
- st.executeUpdate()
+ val count = st.executeUpdate()
+ retOne(st, value, count)
}
- def insertExpr[TT](c: TT)(implicit shape: Shape[TT, U, _], session: Session): Int =
- insert(Query(c)(shape))(session)
-
- /**
- * Insert multiple rows. Uses JDBC's batch update feature if supported by
- * the JDBC driver. Returns Some(rowsAffected), or None if the database
- * returned no row count for some part of the batch. If any part of the
- * batch fails, an exception thrown.
- */
- def insertAll(values: U*)(implicit session: Session): Option[Int] = {
- if(!useBatchUpdates || (values.isInstanceOf[IndexedSeq[_]] && values.length < 2))
- Some( (0 /: values) { _ + insert(_) } )
- else session.withTransaction {
- session.withPreparedStatement(insertStatement) { st =>
+ /** Insert multiple rows. Uses JDBC's batch update feature if supported by
+ * the JDBC driver. Returns Some(rowsAffected), or None if the database
+ * returned no row count for some part of the batch. If any part of the
+ * batch fails, an exception is thrown. */
+ def insertAll(values: U*)(implicit session: Session): RetMany = session.withTransaction {
+ if(!useBatchUpdates || (values.isInstanceOf[IndexedSeq[_]] && values.length < 2)) {
+ retMany(values, values.map(insert))
+ } else {
+ prepared(insertStatement) { st =>
st.clearParameters()
for(value <- values) {
unpackable.linearizer.narrowedLinearizer.asInstanceOf[RecordLinearizer[U]].setParameter(driver, new PositionedParameters(st), Some(value))
st.addBatch()
}
var unknown = false
- var count = 0
- for((res, idx) <- st.executeBatch().zipWithIndex) res match {
- case Statement.SUCCESS_NO_INFO => unknown = true
- case Statement.EXECUTE_FAILED =>
- throw new SlickException("Failed to insert row #" + (idx+1))
- case i => count += i
- }
- if(unknown) None else Some(count)
+ val counts = st.executeBatch()
+ retManyBatch(st, values, counts)
}
}
}
- def insert[TT](query: Query[TT, U])(implicit session: Session): Int = {
- val sbr = buildInsertStatement(unpackable.value, query)
- session.withPreparedStatement(insertStatementFor(query)) { st =>
+ def insertInvoker: this.type = this
+ }
+
+ /** An InsertInvoker that can also insert from another query. */
+ trait FullInsertInvoker[U] { this: InsertInvoker[U] =>
+ type RetQuery
+
+ protected def retQuery(st: Statement, updateCount: Int): RetQuery
+
+ def insertExpr[TT](c: TT)(implicit shape: Shape[TT, U, _], session: Session): RetQuery =
+ insert(Query(c)(shape))(session)
+
+ def insert[TT](query: Query[TT, U])(implicit session: Session): RetQuery = {
+ val sbr = builder.buildInsert(query)
+ prepared(insertStatementFor(query)) { st =>
st.clearParameters()
sbr.setter(new PositionedParameters(st), null)
- st.executeUpdate()
+ val count = st.executeUpdate()
+ retQuery(st, count)
}
}
+ }
- def insertInvoker: this.type = this
+ /** Pseudo-invoker for running INSERT calls and returning affected row counts. */
+ class CountingInsertInvoker[U](unpackable: ShapedValue[_, U])
+ extends InsertInvoker[U](unpackable) with FullInsertInvoker[U] {
+
+ type RetOne = Int
+ type RetMany = Option[Int]
+ type RetQuery = Int
+
+ protected def retOne(st: Statement, value: U, updateCount: Int) = updateCount
+
+ protected def retMany(values: Seq[U], individual: Seq[RetOne]) = Some(individual.sum)
+
+ protected def retManyBatch(st: Statement, values: Seq[U], updateCounts: Array[Int]) = {
+ var unknown = false
+ var count = 0
+ for((res, idx) <- updateCounts.zipWithIndex) res match {
+ case Statement.SUCCESS_NO_INFO => unknown = true
+ case Statement.EXECUTE_FAILED =>
+ throw new SlickException("Failed to insert row #" + (idx+1))
+ case i => count += i
+ }
+ if(unknown) None else Some(count)
+ }
+
+ protected def retQuery(st: Statement, updateCount: Int) = updateCount
+
+ def returning[RT, RU](value: RT)(implicit shape: Shape[RT, RU, _]) =
+ new KeysInsertInvoker[U, RU](unpackable, new ShapedValue[RT, RU](value, shape))
+ }
+
+ /** Base class with common functionality for KeysInsertInvoker and MappedKeysInsertInvoker. */
+ abstract class AbstractKeysInsertInvoker[U, RU](unpackable: ShapedValue[_, U], keys: ShapedValue[_, RU])
+ extends InsertInvoker[U](unpackable) {
+
+ protected def buildKeysResult(st: Statement): UnitInvoker[RU] = {
+ val lin = keys.linearizer.asInstanceOf[RecordLinearizer[RU]]
+ ResultSetInvoker[RU](_ => st.getGeneratedKeys)(pr => lin.getResult(profile, pr))
+ }
+
+ // Returning keys from batch inserts is generally not supported
+ override def useBatchUpdates(implicit session: Session) = false
+
+ protected lazy val keyColumns =
+ builder.buildReturnColumns(keys.packedNode, insertResult.table).map(_.name).toArray
+
+ override protected def prepared[T](sql: String)(f: PreparedStatement => T)(implicit session: Session) =
+ session.withPreparedInsertStatement(sql, keyColumns)(f)
+ }
+
+ /** Pseudo-invoker for running INSERT calls and returning generated keys. */
+ class KeysInsertInvoker[U, RU](unpackable: ShapedValue[_, U], keys: ShapedValue[_, RU])
+ extends AbstractKeysInsertInvoker[U, RU](unpackable, keys) with FullInsertInvoker[U] {
+
+ type RetOne = RU
+ type RetMany = Seq[RU]
+ type RetQuery = RetMany
+
+ protected def retOne(st: Statement, value: U, updateCount: Int) =
+ buildKeysResult(st).first()(null)
+
+ protected def retMany(values: Seq[U], individual: Seq[RetOne]) = individual
+
+ protected def retManyBatch(st: Statement, values: Seq[U], updateCounts: Array[Int]) = {
+ implicit val session: Session = null
+ buildKeysResult(st).to[Vector]
+ }
+
+ protected def retQuery(st: Statement, updateCount: Int) = {
+ implicit val session: Session = null
+ buildKeysResult(st).to[Vector]
+ }
+
+ def into[R](f: (U, RU) => R) = new MappedKeysInsertInvoker[U, RU, R](unpackable, keys, f)
+ }
+
+ /** Pseudo-invoker for running INSERT calls and returning generated keys combined with the values. */
+ class MappedKeysInsertInvoker[U, RU, R](unpackable: ShapedValue[_, U], keys: ShapedValue[_, RU],
+ tr: (U, RU) => R) extends AbstractKeysInsertInvoker[U, RU](unpackable, keys) {
+
+ type RetOne = R
+ type RetMany = Seq[R]
+
+ protected def retOne(st: Statement, value: U, updateCount: Int) = {
+ val ru = buildKeysResult(st).first()(null)
+ tr(value, ru)
+ }
+
+ protected def retMany(values: Seq[U], individual: Seq[RetOne]) = individual
+
+ protected def retManyBatch(st: Statement, values: Seq[U], updateCounts: Array[Int]) = {
+ implicit val session: Session = null
+ val ru = buildKeysResult(st).to[Vector]
+ (values, ru).zipped.map(tr)
+ }
}
- /** Pseudo-invoker for runing UPDATE calls. */
+ /** Pseudo-invoker for running UPDATE calls. */
class UpdateInvoker[T] (query: Query[_, T]) {
protected lazy val built = buildUpdateStatement(query)
@@ -10,6 +10,7 @@ trait BasicProfile extends BasicTableComponent { driver: BasicDriver =>
// Create the different builders -- these methods should be overridden by drivers as needed
def createQueryTemplate[P,R](q: Query[_, R]): BasicQueryTemplate[P,R] = new BasicQueryTemplate[P,R](q, this)
def createQueryBuilder(input: QueryBuilderInput): QueryBuilder = new QueryBuilder(input)
+ def createInsertBuilder(node: Node): InsertBuilder = new InsertBuilder(node)
def createTableDDLBuilder(table: Table[_]): TableDDLBuilder = new TableDDLBuilder(table)
def createColumnDDLBuilder(column: FieldSymbol, table: Table[_]): ColumnDDLBuilder = new ColumnDDLBuilder(column)
def createSequenceDDLBuilder(seq: Sequence[_]): SequenceDDLBuilder = new SequenceDDLBuilder(seq)
@@ -22,8 +23,10 @@ trait BasicProfile extends BasicTableComponent { driver: BasicDriver =>
final def buildSelectStatement(q: Query[_, _]): QueryBuilderResult = createQueryBuilder(q).buildSelect
final def buildUpdateStatement(q: Query[_, _]): QueryBuilderResult = createQueryBuilder(q).buildUpdate
final def buildDeleteStatement(q: Query[_, _]): QueryBuilderResult = createQueryBuilder(q).buildDelete
- final def buildInsertStatement(cb: Any): String = new InsertBuilder(cb).buildInsert
- final def buildInsertStatement(cb: Any, q: Query[_, _]): QueryBuilderResult = new InsertBuilder(cb).buildInsert(q)
+ @deprecated("Use createInsertBuilder.buildInsert", "1.0")
+ final def buildInsertStatement(cb: Any): InsertBuilderResult = createInsertBuilder(Node(cb)).buildInsert
+ @deprecated("Use createInsertBuilder.buildInsert", "1.0")
+ final def buildInsertStatement(cb: Any, q: Query[_, _]): InsertBuilderResult = createInsertBuilder(Node(cb)).buildInsert(q)
final def buildTableDDL(table: Table[_]): DDL = createTableDDLBuilder(table).buildDDL
final def buildSequenceDDL(seq: Sequence[_]): DDL = createSequenceDDLBuilder(seq).buildDDL
@@ -35,8 +38,8 @@ trait BasicProfile extends BasicTableComponent { driver: BasicDriver =>
implicit def columnToOrdered[T](c: Column[T]): ColumnOrdered[T] = c.asc
implicit def queryToQueryInvoker[T, U](q: Query[T, _ <: U]): QueryInvoker[T, U] = new QueryInvoker(q)
implicit def queryToDeleteInvoker(q: Query[_ <: Table[_], _]): DeleteInvoker = new DeleteInvoker(q)
- implicit def columnBaseToInsertInvoker[T](c: ColumnBase[T]) = new InsertInvoker(ShapedValue.createShapedValue(c))
- implicit def shapedValueToInsertInvoker[T, U](u: ShapedValue[T, U]) = new InsertInvoker(u)
+ implicit def columnBaseToInsertInvoker[T](c: ColumnBase[T]) = new CountingInsertInvoker(ShapedValue.createShapedValue(c))
+ implicit def shapedValueToInsertInvoker[T, U](u: ShapedValue[T, U]) = new CountingInsertInvoker(u)
implicit def queryToQueryExecutor[E, U](q: Query[E, U]): QueryExecutor[Seq[U]] = new QueryExecutor[Seq[U]](new QueryBuilderInput(compiler.run(Node(q)), q))
Oops, something went wrong.

1 comment on commit 09a65a8

@ijuma

This comment has been minimized.

Show comment Hide comment
@ijuma

ijuma Aug 9, 2012

Contributor

Nice. :)

Contributor

ijuma commented on 09a65a8 Aug 9, 2012

Nice. :)

Please sign in to comment.