Skip to content

Commit

Permalink
SelectForUpdate support
Browse files Browse the repository at this point in the history
  • Loading branch information
smootoo committed Mar 28, 2016
1 parent a63c2ba commit 919609f
Show file tree
Hide file tree
Showing 17 changed files with 178 additions and 11 deletions.
1 change: 1 addition & 0 deletions slick-testkit/src/main/resources/testkit-reference.conf
Expand Up @@ -45,6 +45,7 @@ testkit {
${testPackage}.TemplateTest
${testPackage}.TransactionTest
${testPackage}.UnionTest
${testPackage}.ForUpdateTest
]
}

Expand Down
@@ -0,0 +1,98 @@
package com.typesafe.slick.testkit.tests

import java.util.concurrent.{CountDownLatch, LinkedBlockingQueue, TimeUnit, ThreadPoolExecutor}

import com.typesafe.slick.testkit.util.{TestkitConfig, AsyncTest, JdbcTestDB}
import org.junit.Assert
import slick.dbio.DBIOAction
import slick.jdbc.{SQLServerProfile, TransactionIsolation}
import slick.util.Logging

import scala.concurrent.duration.Duration
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.util.Failure


class ForUpdateTest extends AsyncTest[JdbcTestDB] with Logging {
import tdb.profile.api._

val tableName = "test_forupdate"
class T(tag: Tag) extends Table[(Int, Option[String])](tag, tableName) {
def id = column[Int]("id", O.PrimaryKey)
def data = column[Option[String]]("data")
def * = (id, data)
}
val ts = TableQuery[T]
def testForUpdate: DBIO[Unit] = {
ifCap(jcap.forUpdate) {
val exe = new ThreadPoolExecutor(2, 2, 1L, TimeUnit.SECONDS, new LinkedBlockingQueue[Runnable]())
@volatile var success = true
val childStartLatch = new CountDownLatch(1)
val thread1Latch = new CountDownLatch(1)
val thread2Latch = new CountDownLatch(1)
val rowSelect: (Int) => Query[T, (Int, Option[String]), Seq] = (i: Int) => ts.filter(_.id === i)
def runAsyncCommands[T](commands: DBIO[Unit], latch: CountDownLatch): DBIO[Unit] = {
exe.execute(new Runnable {
override def run(): Unit = Await.result({
childStartLatch.await()
val f = db.run {
seq(GetTransactionality.map(_._1 shouldBe 0), // make sure not in transaction yet
commands.transactionally)
}
f.onComplete { t => {
latch.countDown()
t match {
case Failure(e) => success = false
case _ =>
}
}}
f
}, TestkitConfig.asyncTimeout)
})
DBIOAction.successful(()) // dummy action to add into pipeline
}
seq(
ifCap(tcap.selectForUpdateRowLocking) { // if database is capable of executing a row locking test
for {
_ <- ts.schema.create
_ <- ts ++= Seq((1, None), (2, None))
// start txn for main thread
_ <- (for {
_ <- GetTransactionality.map(_._1 shouldBe 1) // check in main transaction
// locking read on row.id 1
r1 <- rowSelect(1).forUpdate.result
_ = r1 shouldBe Seq((1, None))
_ <- runAsyncCommands(seq(
// this read is free to continue
rowSelect(2).forUpdate.result.map(x => x shouldBe Seq((2, None))),
rowSelect(2).map(_.data).update(Some("Thread 1 update"))
), thread1Latch)
_ <- runAsyncCommands(seq(
// this read blocks on main thread txn,so check the main update happened first once running
rowSelect(1).forUpdate.result.map(_ shouldBe Seq((1, Some("Main thread update")))),
rowSelect(1).map(_.data).update(Some("Thread 2 update"))
), thread2Latch)
_ = childStartLatch.countDown() // start child threads
_ = thread1Latch.await() // wait for thread 1 to finish
_ <- rowSelect(1).map(_.data).update(Some("Main thread update"))
} yield ()).transactionally
_ = thread2Latch.await()
// Thread 2 update should have overwritten main thread update
_ <- ts.result.map(_.toSet shouldBe Set((1, Some("Thread 2 update")), (2, Some("Thread 1 update"))))
_ <- ts.schema.drop
_ = exe.shutdown()
// Fail the test if there were failures in the child threads
_ = Assert.assertTrue(success)
} yield ()
},
ifNotCap(tcap.selectForUpdateRowLocking) { // a simple test to assert the syntax is valid
for {
_ <- ts.schema.create
_ <- ts ++= Seq((1, None), (2, None))
r1 <- rowSelect(1).forUpdate.result
_ = r1 shouldBe Seq((1, None))
} yield ()
})
}
}
}
Expand Up @@ -183,6 +183,9 @@ object StandardTestDBs {
import profile.api.actionBasedSQLInterpolation

val defaultSchema = config.getString("defaultSchema")
// sqlserver has valid "select for update" syntax, but in testing on Appveyor, the test hangs due to lock escalation
// so exclude from explicit ForUpdate testing
override def capabilities = super.capabilities - TestDB.capabilities.selectForUpdateRowLocking

override def localTables(implicit ec: ExecutionContext): DBIO[Vector[String]] =
ResultSetAction[(String,String,String, String)](_.conn.getMetaData().getTables(testDB, defaultSchema, null, null)).map { ts =>
Expand Down Expand Up @@ -292,6 +295,9 @@ object DerbyDB {
abstract class HsqlDB(confName: String) extends InternalJdbcTestDB(confName) {
val profile = HsqldbProfile
val jdbcDriver = "org.hsqldb.jdbcDriver"
// Hsqldb has valid "select for update" syntax, but in testing, it either takes a whole table lock or no exclusive
// lock at all, so exclude from ForUpdate testing
override def capabilities = super.capabilities - TestDB.capabilities.selectForUpdateRowLocking
override def localTables(implicit ec: ExecutionContext): DBIO[Vector[String]] =
ResultSetAction[(String,String,String, String)](_.conn.getMetaData().getTables(null, "PUBLIC", null, null)).map { ts =>
ts.map(_._3).sorted
Expand Down
Expand Up @@ -37,8 +37,11 @@ object TestDB {
val jdbcMetaGetIndexInfo = new Capability("test.jdbcMetaGetIndexInfo")
/** Supports all tested transaction isolation levels */
val transactionIsolation = new Capability("test.transactionIsolation")
/** Supports select for update row locking */
val selectForUpdateRowLocking = new Capability("test.selectForUpdateRowLocking")

val all = Set(plainSql, jdbcMeta, jdbcMetaGetClientInfoProperties, jdbcMetaGetFunctions, jdbcMetaGetIndexInfo, transactionIsolation)
val all = Set(plainSql, jdbcMeta, jdbcMetaGetClientInfoProperties, jdbcMetaGetFunctions, jdbcMetaGetIndexInfo,
transactionIsolation, selectForUpdateRowLocking)
}

/** Copy a file, expanding it if the source name ends with .gz */
Expand Down
4 changes: 3 additions & 1 deletion slick/src/main/scala/slick/ast/Comprehension.scala
Expand Up @@ -9,7 +9,9 @@ final case class Comprehension(sym: TermSymbol, from: Node, select: Node, where:
groupBy: Option[Node] = None, orderBy: ConstArray[(Node, Ordering)] = ConstArray.empty,
having: Option[Node] = None,
distinct: Option[Node] = None,
fetch: Option[Node] = None, offset: Option[Node] = None) extends DefNode {
fetch: Option[Node] = None,
offset: Option[Node] = None,
forUpdate: Boolean = false) extends DefNode {
type Self = Comprehension
lazy val children = (ConstArray.newBuilder() + from + select ++ where ++ groupBy ++ orderBy.map(_._1) ++ having ++ distinct ++ fetch ++ offset).result
override def childNames =
Expand Down
8 changes: 8 additions & 0 deletions slick/src/main/scala/slick/ast/Node.scala
Expand Up @@ -377,6 +377,14 @@ final case class GroupBy(fromGen: TermSymbol, from: Node, by: Node, identity: Ty
}
}

/** A .forUpdate call */
final case class ForUpdate(generator: TermSymbol, from: Node) extends ComplexFilteredQuery {
type Self = ForUpdate
lazy val children = ConstArray(from)
protected[this] def rebuild(ch: ConstArray[Node]) = copy(from = ch(0))
protected[this] def rebuildWithSymbols(gen: ConstArray[TermSymbol]) = copy(generator = gen(0))
}

/** A .take call. */
final case class Take(from: Node, count: Node) extends SimpleFilteredQuery with BinaryNode {
type Self = Take
Expand Down
Expand Up @@ -81,6 +81,13 @@ class MergeToComprehensions extends Phase {
logger.debug("Merged SortBy into Comprehension:", c2)
(c2, replacements1)

case ForUpdate(s1, f1) =>
val (c1, replacements1) = mergeSortBy(f1, true)
logger.debug("Merging ForUpdate into Comprehension:", Ellipsis(n, List(0)))
val c2 = c1.copy(forUpdate = true) :@ c1.nodeType
logger.debug("Merged ForUpdate into Comprehension:", c2)
(c2, replacements1)

case Distinct(s1, f1, o1) =>
val (c1, replacements1) = mergeSortBy(f1, true)
val (c1a, replacements1a) =
Expand Down
2 changes: 1 addition & 1 deletion slick/src/main/scala/slick/compiler/RemoveFieldNames.scala
Expand Up @@ -20,7 +20,7 @@ class RemoveFieldNames(val alwaysKeepSubqueryNames: Boolean = false) extends Pha
val refTSyms = n.collect[TypeSymbol] {
case Select(_ :@ NominalType(s, _), _) => s
case Union(_, _ :@ CollectionType(_, NominalType(s, _)), _) => s
case Comprehension(_, _ :@ CollectionType(_, NominalType(s, _)), _, _, _, _, _, _, _, _) if alwaysKeepSubqueryNames => s
case Comprehension(_, _ :@ CollectionType(_, NominalType(s, _)), _, _, _, _, _, _, _, _, _) if alwaysKeepSubqueryNames => s
}.toSet
val allTSyms = n.collect[TypeSymbol] { case p: Pure => p.identity }.toSet
val unrefTSyms = allTSyms -- refTSyms
Expand Down
2 changes: 1 addition & 1 deletion slick/src/main/scala/slick/compiler/RewriteBooleans.scala
Expand Up @@ -36,7 +36,7 @@ class RewriteBooleans extends Phase {
case Apply(sym, ch) :@ tpe if isBooleanLike(tpe) =>
toFake(Apply(sym, ch)(n.nodeType).infer())
// Where clauses, join conditions and case clauses need real boolean predicates
case n @ Comprehension(_, _, _, where, _, _, having, _, _, _) =>
case n @ Comprehension(_, _, _, where, _, _, having, _, _, _, _) =>
n.copy(where = where.map(toReal), having = having.map(toReal)) :@ n.nodeType
case n @ Join(_, _, _, _, _, on) =>
n.copy(on = toReal(on)) :@ n.nodeType
Expand Down
Expand Up @@ -14,9 +14,9 @@ class SpecializeParameters extends Phase {
state.map(ClientSideOp.mapServerSide(_, keepType = true)(transformServerSide))

def transformServerSide(n: Node): Node = {
val cs = n.collect { case c @ Comprehension(_, _, _, _, _, _, _, _, Some(_: QueryParameter), _) => c }
val cs = n.collect { case c @ Comprehension(_, _, _, _, _, _, _, _, Some(_: QueryParameter), _, _) => c }
logger.debug("Affected fetch clauses in: "+cs.mkString(", "))
cs.foldLeft(n) { case (n, c @ Comprehension(_, _, _, _, _, _, _, _, Some(fetch: QueryParameter), _)) =>
cs.foldLeft(n) { case (n, c @ Comprehension(_, _, _, _, _, _, _, _, Some(fetch: QueryParameter), _, _)) =>
val compiledFetchParam = QueryParameter(fetch.extractor, ScalaBaseType.longType)
val guarded = n.replace({ case c2: Comprehension if c2 == c => c2.copy(fetch = Some(LiteralNode(0L))) }, keepType = true)
val fallback = n.replace({ case c2: Comprehension if c2 == c => c2.copy(fetch = Some(compiledFetchParam)) }, keepType = true)
Expand Down
7 changes: 7 additions & 0 deletions slick/src/main/scala/slick/jdbc/DB2Profile.scala
Expand Up @@ -108,6 +108,13 @@ trait DB2Profile extends JdbcProfile {
expr(n)
if(o.direction.desc) b += " desc"
}

override protected def buildForUpdateClause(forUpdate: Boolean) = {
super.buildForUpdateClause(forUpdate)
if(forUpdate) {
b" with RS "
}
}
}

class TableDDLBuilder(table: Table[_]) extends super.TableDDLBuilder(table) {
Expand Down
7 changes: 7 additions & 0 deletions slick/src/main/scala/slick/jdbc/DerbyProfile.scala
Expand Up @@ -125,6 +125,13 @@ trait DerbyProfile extends JdbcProfile {
override protected val supportsLiteralGroupBy = true
override protected val quotedJdbcFns = Some(Vector(Library.User))

override protected def buildForUpdateClause(forUpdate: Boolean) = {
super.buildForUpdateClause(forUpdate)
if (forUpdate) {
b" with RS "
}
}

override def expr(c: Node, skipParens: Boolean = false): Unit = c match {
case Library.Cast(ch @ _*) =>
/* Work around DERBY-2072 by casting numeric values first to CHAR and
Expand Down
4 changes: 3 additions & 1 deletion slick/src/main/scala/slick/jdbc/JdbcCapabilities.scala
Expand Up @@ -29,10 +29,12 @@ object JdbcCapabilities {
val distinguishesIntTypes = Capability("jdbc.distinguishesIntTypes")
/** Has a datatype directly corresponding to Scala Byte */
val supportsByte = Capability("jdbc.supportsByte")
/** Supports FOR UPDATE row level locking */
val forUpdate = Capability("jdbc.forUpdate")

/** Supports all JdbcProfile features which do not have separate capability values */
val other = Capability("jdbc.other")

/** All JDBC capabilities */
val all = Set(other, createModel, forceInsert, insertOrUpdate, mutable, returnInsertKey, defaultValueMetaData, booleanMetaData, nullableNoDefault, distinguishesIntTypes, supportsByte, returnInsertOther)
val all = Set(other, createModel, forceInsert, insertOrUpdate, mutable, returnInsertKey, defaultValueMetaData, booleanMetaData, nullableNoDefault, distinguishesIntTypes, supportsByte, returnInsertOther, forUpdate)
}
Expand Up @@ -148,6 +148,7 @@ trait JdbcStatementBuilderComponent { self: JdbcProfile =>
buildHavingClause(c.having)
buildOrderByClause(c.orderBy)
if(!limit0) buildFetchOffsetClause(c.fetch, c.offset)
buildForUpdateClause(c.forUpdate)
currentUniqueFrom = oldUniqueFrom
}

Expand Down Expand Up @@ -248,6 +249,12 @@ trait JdbcStatementBuilderComponent { self: JdbcProfile =>
}
}

protected def buildForUpdateClause(forUpdate: Boolean) = building(OtherPart) {
if(forUpdate) {
b"\nfor update "
}
}

protected def buildSelectPart(n: Node): Unit = n match {
case c: Comprehension =>
b"\["
Expand Down Expand Up @@ -437,7 +444,7 @@ trait JdbcStatementBuilderComponent { self: JdbcProfile =>

def buildUpdate: SQLBuilder.Result = {
val (gen, from, where, select) = tree match {
case Comprehension(sym, from: TableNode, Pure(select, _), where, None, _, None, None, None, None) => select match {
case Comprehension(sym, from: TableNode, Pure(select, _), where, None, _, None, None, None, None, false) => select match {
case f @ Select(Ref(struct), _) if struct == sym => (sym, from, where, ConstArray(f.field))
case ProductNode(ch) if ch.forall{ case Select(Ref(struct), _) if struct == sym => true; case _ => false} =>
(sym, from, where, ch.map{ case Select(Ref(_), field) => field })
Expand All @@ -461,9 +468,9 @@ trait JdbcStatementBuilderComponent { self: JdbcProfile =>
def fail(msg: String) =
throw new SlickException("Invalid query for DELETE statement: " + msg)
val (gen, from, where) = tree match {
case Comprehension(sym, from, Pure(select, _), where, _, _, None, distinct, fetch, offset) =>
if(fetch.isDefined || offset.isDefined || distinct.isDefined)
fail(".take, .drop and .distinct are not supported for deleting")
case Comprehension(sym, from, Pure(select, _), where, _, _, None, distinct, fetch, offset, forUpdate) =>
if(fetch.isDefined || offset.isDefined || distinct.isDefined || forUpdate)
fail(".take, .drop .forUpdate and .distinct are not supported for deleting")
from match {
case from: TableNode => (sym, from, where)
case from => fail("A single source table is required, found: "+from)
Expand Down
13 changes: 13 additions & 0 deletions slick/src/main/scala/slick/jdbc/SQLServerProfile.scala
Expand Up @@ -153,6 +153,19 @@ trait SQLServerProfile extends JdbcProfile {
if(o.direction.desc) b" desc"
}

override protected def buildFromClause(from: Seq[(TermSymbol, Node)]) = {
super.buildFromClause(from)
tree match {
// SQL Server "select for update" syntax
case c: Comprehension => if(c.forUpdate) b" with (updlock,rowlock) "
case _ =>
}
}

override protected def buildForUpdateClause(forUpdate: Boolean) = {
// SQLSever doesn't have "select for update" syntax, so use with (updlock,rowlock) in from clause
}

override def expr(n: Node, skipParens: Boolean = false): Unit = n match {
// Cast bind variables of type TIME to TIME (otherwise they're treated as TIMESTAMP)
case c @ LiteralNode(v) if c.volatileHint && jdbcTypeFor(c.nodeType) == columnTypes.timeJdbcType =>
Expand Down
1 change: 1 addition & 0 deletions slick/src/main/scala/slick/jdbc/SQLiteProfile.scala
Expand Up @@ -91,6 +91,7 @@ trait SQLiteProfile extends JdbcProfile {
- JdbcCapabilities.booleanMetaData
- JdbcCapabilities.supportsByte
- JdbcCapabilities.distinguishesIntTypes
- JdbcCapabilities.forUpdate
)

class ModelBuilder(mTables: Seq[MTable], ignoreInvalidDefaults: Boolean)(implicit ec: ExecutionContext) extends JdbcModelBuilder(mTables, ignoreInvalidDefaults) {
Expand Down
5 changes: 5 additions & 0 deletions slick/src/main/scala/slick/lifted/Query.scala
Expand Up @@ -155,6 +155,11 @@ sealed abstract class Query[+E, U, C[_]] extends QueryBase[C[U]] { self =>
new WrappingQuery[(G, Query[P, U, Seq]), (T, Query[P, U, Seq]), C](group, key.zip(value))
}

/** Specify part of a select statement for update and marked for row level locking */
def forUpdate: Query[E, U, C] = {
val generator = new AnonSymbol
new WrappingQuery[E, U, C](ForUpdate(generator, toNode), shaped)
}
def encodeRef(path: Node): Query[E, U, C] = new Query[E, U, C] {
val shaped = self.shaped.encodeRef(path)
def toNode = path
Expand Down

0 comments on commit 919609f

Please sign in to comment.