Skip to content

Commit

Permalink
Allow returning of auto-generated keys when inserting data.
Browse files Browse the repository at this point in the history
  • Loading branch information
szeiger committed Aug 9, 2012
1 parent a9bd7fb commit 09a65a8
Show file tree
Hide file tree
Showing 13 changed files with 281 additions and 89 deletions.
8 changes: 8 additions & 0 deletions src/main/scala/scala/slick/driver/AccessDriver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import java.sql.{Blob, Clob, Date, Time, Timestamp, SQLException}
* <li>Trying to use <code>java.sql.Blob</code> objects causes a NPE in the
* JdbcOdbcDriver. Binary data in the form of <code>Array[Byte]</code> is
* supported.</li>
* <li>Returning columns from an INSERT operation is not supported. Trying
* to execute such an insert statement throws a SlickException.</li>
* </ul>
*
* @author szeiger
Expand All @@ -48,6 +50,7 @@ trait AccessDriver extends ExtendedDriver { driver =>
override val typeMapperDelegates = new TypeMapperDelegates(retryCount)

override def createQueryBuilder(input: QueryBuilderInput): QueryBuilder = new QueryBuilder(input)
override def createInsertBuilder(node: Node): InsertBuilder = new InsertBuilder(node)
override def createTableDDLBuilder(table: Table[_]): TableDDLBuilder = new TableDDLBuilder(table)
override def createColumnDDLBuilder(column: FieldSymbol, table: Table[_]): ColumnDDLBuilder = new ColumnDDLBuilder(column)

Expand Down Expand Up @@ -135,6 +138,11 @@ trait AccessDriver extends ExtendedDriver { driver =>
override protected def buildFetchOffsetClause(fetch: Option[Long], offset: Option[Long]) = ()
}

class InsertBuilder(node: Node) extends super.InsertBuilder(node) {
override def buildReturnColumns(node: Node, table: String): IndexedSeq[FieldSymbol] =
throw new SlickException("Returning columns from INSERT statements is not supported by Access")
}

class TableDDLBuilder(table: Table[_]) extends super.TableDDLBuilder(table) {
override protected def addForeignKey(fk: ForeignKey[_ <: TableNode, _], sb: StringBuilder) {
sb append "CONSTRAINT " append quoteIdentifier(fk.name) append " FOREIGN KEY("
Expand Down
187 changes: 148 additions & 39 deletions src/main/scala/scala/slick/driver/BasicInvokerComponent.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package scala.slick.driver

import java.sql.{Statement, PreparedStatement}
import scala.slick.SlickException
import scala.slick.ast.Node
import scala.slick.lifted.{Query, Shape, ShapedValue}
import scala.slick.session.{Session, PositionedParameters, PositionedResult}
import scala.slick.util.RecordLinearizer
import scala.slick.jdbc.{UnitInvokerMixin, MutatingStatementInvoker, MutatingUnitInvoker}
import scala.slick.jdbc.{UnitInvoker, UnitInvokerMixin, MutatingStatementInvoker, MutatingUnitInvoker, ResultSetInvoker}

trait BasicInvokerComponent { driver: BasicDriver =>

Expand All @@ -22,7 +23,7 @@ trait BasicInvokerComponent { driver: BasicDriver =>
def invoker: this.type = this
}

/** Pseudo-invoker for runing DELETE calls. */
/** Pseudo-invoker for running DELETE calls. */
class DeleteInvoker(query: Query[_, _]) {
protected lazy val built = buildDeleteStatement(query)

Expand All @@ -36,68 +37,176 @@ trait BasicInvokerComponent { driver: BasicDriver =>
def deleteInvoker: this.type = this
}

/** Pseudo-invoker for runing INSERT calls. */
class InsertInvoker[U](unpackable: ShapedValue[_, U]) {
lazy val insertStatement = buildInsertStatement(unpackable.value)
def insertStatementFor[TT](query: Query[TT, U]): String = buildInsertStatement(unpackable.value, query).sql
/** Pseudo-invoker for running INSERT calls. */
abstract class InsertInvoker[U](unpackable: ShapedValue[_, U]) {
protected lazy val builder = createInsertBuilder(Node(unpackable.value))

type RetOne
type RetMany

protected def retOne(st: Statement, value: U, updateCount: Int): RetOne
protected def retMany(values: Seq[U], individual: Seq[RetOne]): RetMany
protected def retManyBatch(st: Statement, values: Seq[U], updateCounts: Array[Int]): RetMany

protected lazy val insertResult = builder.buildInsert
lazy val insertStatement = insertResult.sql
def insertStatementFor[TT](query: Query[TT, U]): String = builder.buildInsert(query).sql
def insertStatementFor[TT](c: TT)(implicit shape: Shape[TT, U, _]): String = insertStatementFor(Query(c)(shape))

def useBatchUpdates(implicit session: Session) = session.capabilities.supportsBatchUpdates

/**
* Insert a single row.
*/
def insert(value: U)(implicit session: Session): Int = session.withPreparedStatement(insertStatement) { st =>
protected def prepared[T](sql: String)(f: PreparedStatement => T)(implicit session: Session) =
session.withPreparedStatement(sql)(f)

/** Insert a single row. */
def insert(value: U)(implicit session: Session): RetOne = prepared(insertStatement) { st =>
st.clearParameters()
unpackable.linearizer.narrowedLinearizer.asInstanceOf[RecordLinearizer[U]].setParameter(driver, new PositionedParameters(st), Some(value))
st.executeUpdate()
val count = st.executeUpdate()
retOne(st, value, count)
}

def insertExpr[TT](c: TT)(implicit shape: Shape[TT, U, _], session: Session): Int =
insert(Query(c)(shape))(session)

/**
* Insert multiple rows. Uses JDBC's batch update feature if supported by
* the JDBC driver. Returns Some(rowsAffected), or None if the database
* returned no row count for some part of the batch. If any part of the
* batch fails, an exception thrown.
*/
def insertAll(values: U*)(implicit session: Session): Option[Int] = {
if(!useBatchUpdates || (values.isInstanceOf[IndexedSeq[_]] && values.length < 2))
Some( (0 /: values) { _ + insert(_) } )
else session.withTransaction {
session.withPreparedStatement(insertStatement) { st =>
/** Insert multiple rows. Uses JDBC's batch update feature if supported by
* the JDBC driver. Returns Some(rowsAffected), or None if the database
* returned no row count for some part of the batch. If any part of the
* batch fails, an exception is thrown. */
def insertAll(values: U*)(implicit session: Session): RetMany = session.withTransaction {
if(!useBatchUpdates || (values.isInstanceOf[IndexedSeq[_]] && values.length < 2)) {
retMany(values, values.map(insert))
} else {
prepared(insertStatement) { st =>
st.clearParameters()
for(value <- values) {
unpackable.linearizer.narrowedLinearizer.asInstanceOf[RecordLinearizer[U]].setParameter(driver, new PositionedParameters(st), Some(value))
st.addBatch()
}
var unknown = false
var count = 0
for((res, idx) <- st.executeBatch().zipWithIndex) res match {
case Statement.SUCCESS_NO_INFO => unknown = true
case Statement.EXECUTE_FAILED =>
throw new SlickException("Failed to insert row #" + (idx+1))
case i => count += i
}
if(unknown) None else Some(count)
val counts = st.executeBatch()
retManyBatch(st, values, counts)
}
}
}

def insert[TT](query: Query[TT, U])(implicit session: Session): Int = {
val sbr = buildInsertStatement(unpackable.value, query)
session.withPreparedStatement(insertStatementFor(query)) { st =>
def insertInvoker: this.type = this
}

/** An InsertInvoker that can also insert from another query. */
trait FullInsertInvoker[U] { this: InsertInvoker[U] =>
type RetQuery

protected def retQuery(st: Statement, updateCount: Int): RetQuery

def insertExpr[TT](c: TT)(implicit shape: Shape[TT, U, _], session: Session): RetQuery =
insert(Query(c)(shape))(session)

def insert[TT](query: Query[TT, U])(implicit session: Session): RetQuery = {
val sbr = builder.buildInsert(query)
prepared(insertStatementFor(query)) { st =>
st.clearParameters()
sbr.setter(new PositionedParameters(st), null)
st.executeUpdate()
val count = st.executeUpdate()
retQuery(st, count)
}
}
}

def insertInvoker: this.type = this
/** Pseudo-invoker for running INSERT calls and returning affected row counts. */
class CountingInsertInvoker[U](unpackable: ShapedValue[_, U])
extends InsertInvoker[U](unpackable) with FullInsertInvoker[U] {

type RetOne = Int
type RetMany = Option[Int]
type RetQuery = Int

protected def retOne(st: Statement, value: U, updateCount: Int) = updateCount

protected def retMany(values: Seq[U], individual: Seq[RetOne]) = Some(individual.sum)

protected def retManyBatch(st: Statement, values: Seq[U], updateCounts: Array[Int]) = {
var unknown = false
var count = 0
for((res, idx) <- updateCounts.zipWithIndex) res match {
case Statement.SUCCESS_NO_INFO => unknown = true
case Statement.EXECUTE_FAILED =>
throw new SlickException("Failed to insert row #" + (idx+1))
case i => count += i
}
if(unknown) None else Some(count)
}

protected def retQuery(st: Statement, updateCount: Int) = updateCount

def returning[RT, RU](value: RT)(implicit shape: Shape[RT, RU, _]) =
new KeysInsertInvoker[U, RU](unpackable, new ShapedValue[RT, RU](value, shape))
}

/** Base class with common functionality for KeysInsertInvoker and MappedKeysInsertInvoker. */
abstract class AbstractKeysInsertInvoker[U, RU](unpackable: ShapedValue[_, U], keys: ShapedValue[_, RU])
extends InsertInvoker[U](unpackable) {

protected def buildKeysResult(st: Statement): UnitInvoker[RU] = {
val lin = keys.linearizer.asInstanceOf[RecordLinearizer[RU]]
ResultSetInvoker[RU](_ => st.getGeneratedKeys)(pr => lin.getResult(profile, pr))
}

// Returning keys from batch inserts is generally not supported
override def useBatchUpdates(implicit session: Session) = false

protected lazy val keyColumns =
builder.buildReturnColumns(keys.packedNode, insertResult.table).map(_.name).toArray

override protected def prepared[T](sql: String)(f: PreparedStatement => T)(implicit session: Session) =
session.withPreparedInsertStatement(sql, keyColumns)(f)
}

/** Pseudo-invoker for running INSERT calls and returning generated keys. */
class KeysInsertInvoker[U, RU](unpackable: ShapedValue[_, U], keys: ShapedValue[_, RU])
extends AbstractKeysInsertInvoker[U, RU](unpackable, keys) with FullInsertInvoker[U] {

type RetOne = RU
type RetMany = Seq[RU]
type RetQuery = RetMany

protected def retOne(st: Statement, value: U, updateCount: Int) =
buildKeysResult(st).first()(null)

protected def retMany(values: Seq[U], individual: Seq[RetOne]) = individual

protected def retManyBatch(st: Statement, values: Seq[U], updateCounts: Array[Int]) = {
implicit val session: Session = null
buildKeysResult(st).to[Vector]
}

protected def retQuery(st: Statement, updateCount: Int) = {
implicit val session: Session = null
buildKeysResult(st).to[Vector]
}

def into[R](f: (U, RU) => R) = new MappedKeysInsertInvoker[U, RU, R](unpackable, keys, f)
}

/** Pseudo-invoker for running INSERT calls and returning generated keys combined with the values. */
class MappedKeysInsertInvoker[U, RU, R](unpackable: ShapedValue[_, U], keys: ShapedValue[_, RU],
tr: (U, RU) => R) extends AbstractKeysInsertInvoker[U, RU](unpackable, keys) {

type RetOne = R
type RetMany = Seq[R]

protected def retOne(st: Statement, value: U, updateCount: Int) = {
val ru = buildKeysResult(st).first()(null)
tr(value, ru)
}

protected def retMany(values: Seq[U], individual: Seq[RetOne]) = individual

protected def retManyBatch(st: Statement, values: Seq[U], updateCounts: Array[Int]) = {
implicit val session: Session = null
val ru = buildKeysResult(st).to[Vector]
(values, ru).zipped.map(tr)
}
}

/** Pseudo-invoker for runing UPDATE calls. */
/** Pseudo-invoker for running UPDATE calls. */
class UpdateInvoker[T] (query: Query[_, T]) {
protected lazy val built = buildUpdateStatement(query)

Expand Down
11 changes: 7 additions & 4 deletions src/main/scala/scala/slick/driver/BasicProfile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ trait BasicProfile extends BasicTableComponent { driver: BasicDriver =>
// Create the different builders -- these methods should be overridden by drivers as needed
def createQueryTemplate[P,R](q: Query[_, R]): BasicQueryTemplate[P,R] = new BasicQueryTemplate[P,R](q, this)
def createQueryBuilder(input: QueryBuilderInput): QueryBuilder = new QueryBuilder(input)
def createInsertBuilder(node: Node): InsertBuilder = new InsertBuilder(node)
def createTableDDLBuilder(table: Table[_]): TableDDLBuilder = new TableDDLBuilder(table)
def createColumnDDLBuilder(column: FieldSymbol, table: Table[_]): ColumnDDLBuilder = new ColumnDDLBuilder(column)
def createSequenceDDLBuilder(seq: Sequence[_]): SequenceDDLBuilder = new SequenceDDLBuilder(seq)
Expand All @@ -22,8 +23,10 @@ trait BasicProfile extends BasicTableComponent { driver: BasicDriver =>
final def buildSelectStatement(q: Query[_, _]): QueryBuilderResult = createQueryBuilder(q).buildSelect
final def buildUpdateStatement(q: Query[_, _]): QueryBuilderResult = createQueryBuilder(q).buildUpdate
final def buildDeleteStatement(q: Query[_, _]): QueryBuilderResult = createQueryBuilder(q).buildDelete
final def buildInsertStatement(cb: Any): String = new InsertBuilder(cb).buildInsert
final def buildInsertStatement(cb: Any, q: Query[_, _]): QueryBuilderResult = new InsertBuilder(cb).buildInsert(q)
@deprecated("Use createInsertBuilder.buildInsert", "1.0")
final def buildInsertStatement(cb: Any): InsertBuilderResult = createInsertBuilder(Node(cb)).buildInsert
@deprecated("Use createInsertBuilder.buildInsert", "1.0")
final def buildInsertStatement(cb: Any, q: Query[_, _]): InsertBuilderResult = createInsertBuilder(Node(cb)).buildInsert(q)
final def buildTableDDL(table: Table[_]): DDL = createTableDDLBuilder(table).buildDDL
final def buildSequenceDDL(seq: Sequence[_]): DDL = createSequenceDDLBuilder(seq).buildDDL

Expand All @@ -35,8 +38,8 @@ trait BasicProfile extends BasicTableComponent { driver: BasicDriver =>
implicit def columnToOrdered[T](c: Column[T]): ColumnOrdered[T] = c.asc
implicit def queryToQueryInvoker[T, U](q: Query[T, _ <: U]): QueryInvoker[T, U] = new QueryInvoker(q)
implicit def queryToDeleteInvoker(q: Query[_ <: Table[_], _]): DeleteInvoker = new DeleteInvoker(q)
implicit def columnBaseToInsertInvoker[T](c: ColumnBase[T]) = new InsertInvoker(ShapedValue.createShapedValue(c))
implicit def shapedValueToInsertInvoker[T, U](u: ShapedValue[T, U]) = new InsertInvoker(u)
implicit def columnBaseToInsertInvoker[T](c: ColumnBase[T]) = new CountingInsertInvoker(ShapedValue.createShapedValue(c))
implicit def shapedValueToInsertInvoker[T, U](u: ShapedValue[T, U]) = new CountingInsertInvoker(u)

implicit def queryToQueryExecutor[E, U](q: Query[E, U]): QueryExecutor[Seq[U]] = new QueryExecutor[Seq[U]](new QueryBuilderInput(compiler.run(Node(q)), q))

Expand Down
Loading

1 comment on commit 09a65a8

@ijuma
Copy link
Contributor

@ijuma ijuma commented on 09a65a8 Aug 9, 2012

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. :)

Please sign in to comment.