Permalink
Browse files

SelectForUpdate support

  • Loading branch information...
1 parent a63c2ba commit 919609fe698836306e90eead7250180fbc61487f @smootoo smootoo committed Mar 1, 2016
@@ -45,6 +45,7 @@ testkit {
${testPackage}.TemplateTest
${testPackage}.TransactionTest
${testPackage}.UnionTest
+ ${testPackage}.ForUpdateTest
]
}
@@ -0,0 +1,98 @@
+package com.typesafe.slick.testkit.tests
+
+import java.util.concurrent.{CountDownLatch, LinkedBlockingQueue, TimeUnit, ThreadPoolExecutor}
+
+import com.typesafe.slick.testkit.util.{TestkitConfig, AsyncTest, JdbcTestDB}
+import org.junit.Assert
+import slick.dbio.DBIOAction
+import slick.jdbc.{SQLServerProfile, TransactionIsolation}
+import slick.util.Logging
+
+import scala.concurrent.duration.Duration
+import scala.concurrent.{Await, ExecutionContext, Future}
+import scala.util.Failure
+
+
+class ForUpdateTest extends AsyncTest[JdbcTestDB] with Logging {
+ import tdb.profile.api._
+
+ val tableName = "test_forupdate"
+ class T(tag: Tag) extends Table[(Int, Option[String])](tag, tableName) {
+ def id = column[Int]("id", O.PrimaryKey)
+ def data = column[Option[String]]("data")
+ def * = (id, data)
+ }
+ val ts = TableQuery[T]
+ def testForUpdate: DBIO[Unit] = {
+ ifCap(jcap.forUpdate) {
+ val exe = new ThreadPoolExecutor(2, 2, 1L, TimeUnit.SECONDS, new LinkedBlockingQueue[Runnable]())
+ @volatile var success = true
+ val childStartLatch = new CountDownLatch(1)
+ val thread1Latch = new CountDownLatch(1)
+ val thread2Latch = new CountDownLatch(1)
+ val rowSelect: (Int) => Query[T, (Int, Option[String]), Seq] = (i: Int) => ts.filter(_.id === i)
+ def runAsyncCommands[T](commands: DBIO[Unit], latch: CountDownLatch): DBIO[Unit] = {
+ exe.execute(new Runnable {
+ override def run(): Unit = Await.result({
+ childStartLatch.await()
+ val f = db.run {
+ seq(GetTransactionality.map(_._1 shouldBe 0), // make sure not in transaction yet
+ commands.transactionally)
+ }
+ f.onComplete { t => {
+ latch.countDown()
+ t match {
+ case Failure(e) => success = false
+ case _ =>
+ }
+ }}
+ f
+ }, TestkitConfig.asyncTimeout)
+ })
+ DBIOAction.successful(()) // dummy action to add into pipeline
+ }
+ seq(
+ ifCap(tcap.selectForUpdateRowLocking) { // if database is capable of executing a row locking test
+ for {
+ _ <- ts.schema.create
+ _ <- ts ++= Seq((1, None), (2, None))
+ // start txn for main thread
+ _ <- (for {
+ _ <- GetTransactionality.map(_._1 shouldBe 1) // check in main transaction
+ // locking read on row.id 1
+ r1 <- rowSelect(1).forUpdate.result
+ _ = r1 shouldBe Seq((1, None))
+ _ <- runAsyncCommands(seq(
+ // this read is free to continue
+ rowSelect(2).forUpdate.result.map(x => x shouldBe Seq((2, None))),
+ rowSelect(2).map(_.data).update(Some("Thread 1 update"))
+ ), thread1Latch)
+ _ <- runAsyncCommands(seq(
+ // this read blocks on main thread txn,so check the main update happened first once running
+ rowSelect(1).forUpdate.result.map(_ shouldBe Seq((1, Some("Main thread update")))),
+ rowSelect(1).map(_.data).update(Some("Thread 2 update"))
+ ), thread2Latch)
+ _ = childStartLatch.countDown() // start child threads
+ _ = thread1Latch.await() // wait for thread 1 to finish
+ _ <- rowSelect(1).map(_.data).update(Some("Main thread update"))
+ } yield ()).transactionally
+ _ = thread2Latch.await()
+ // Thread 2 update should have overwritten main thread update
+ _ <- ts.result.map(_.toSet shouldBe Set((1, Some("Thread 2 update")), (2, Some("Thread 1 update"))))
+ _ <- ts.schema.drop
+ _ = exe.shutdown()
+ // Fail the test if there were failures in the child threads
+ _ = Assert.assertTrue(success)
+ } yield ()
+ },
+ ifNotCap(tcap.selectForUpdateRowLocking) { // a simple test to assert the syntax is valid
+ for {
+ _ <- ts.schema.create
+ _ <- ts ++= Seq((1, None), (2, None))
+ r1 <- rowSelect(1).forUpdate.result
+ _ = r1 shouldBe Seq((1, None))
+ } yield ()
+ })
+ }
+ }
+}
@@ -183,6 +183,9 @@ object StandardTestDBs {
import profile.api.actionBasedSQLInterpolation
val defaultSchema = config.getString("defaultSchema")
+ // sqlserver has valid "select for update" syntax, but in testing on Appveyor, the test hangs due to lock escalation
+ // so exclude from explicit ForUpdate testing
+ override def capabilities = super.capabilities - TestDB.capabilities.selectForUpdateRowLocking
override def localTables(implicit ec: ExecutionContext): DBIO[Vector[String]] =
ResultSetAction[(String,String,String, String)](_.conn.getMetaData().getTables(testDB, defaultSchema, null, null)).map { ts =>
@@ -292,6 +295,9 @@ object DerbyDB {
abstract class HsqlDB(confName: String) extends InternalJdbcTestDB(confName) {
val profile = HsqldbProfile
val jdbcDriver = "org.hsqldb.jdbcDriver"
+ // Hsqldb has valid "select for update" syntax, but in testing, it either takes a whole table lock or no exclusive
+ // lock at all, so exclude from ForUpdate testing
+ override def capabilities = super.capabilities - TestDB.capabilities.selectForUpdateRowLocking
override def localTables(implicit ec: ExecutionContext): DBIO[Vector[String]] =
ResultSetAction[(String,String,String, String)](_.conn.getMetaData().getTables(null, "PUBLIC", null, null)).map { ts =>
ts.map(_._3).sorted
@@ -37,8 +37,11 @@ object TestDB {
val jdbcMetaGetIndexInfo = new Capability("test.jdbcMetaGetIndexInfo")
/** Supports all tested transaction isolation levels */
val transactionIsolation = new Capability("test.transactionIsolation")
+ /** Supports select for update row locking */
+ val selectForUpdateRowLocking = new Capability("test.selectForUpdateRowLocking")
- val all = Set(plainSql, jdbcMeta, jdbcMetaGetClientInfoProperties, jdbcMetaGetFunctions, jdbcMetaGetIndexInfo, transactionIsolation)
+ val all = Set(plainSql, jdbcMeta, jdbcMetaGetClientInfoProperties, jdbcMetaGetFunctions, jdbcMetaGetIndexInfo,
+ transactionIsolation, selectForUpdateRowLocking)
}
/** Copy a file, expanding it if the source name ends with .gz */
@@ -9,7 +9,9 @@ final case class Comprehension(sym: TermSymbol, from: Node, select: Node, where:
groupBy: Option[Node] = None, orderBy: ConstArray[(Node, Ordering)] = ConstArray.empty,
having: Option[Node] = None,
distinct: Option[Node] = None,
- fetch: Option[Node] = None, offset: Option[Node] = None) extends DefNode {
+ fetch: Option[Node] = None,
+ offset: Option[Node] = None,
+ forUpdate: Boolean = false) extends DefNode {
type Self = Comprehension
lazy val children = (ConstArray.newBuilder() + from + select ++ where ++ groupBy ++ orderBy.map(_._1) ++ having ++ distinct ++ fetch ++ offset).result
override def childNames =
@@ -377,6 +377,14 @@ final case class GroupBy(fromGen: TermSymbol, from: Node, by: Node, identity: Ty
}
}
+/** A .forUpdate call */
+final case class ForUpdate(generator: TermSymbol, from: Node) extends ComplexFilteredQuery {
+ type Self = ForUpdate
+ lazy val children = ConstArray(from)
+ protected[this] def rebuild(ch: ConstArray[Node]) = copy(from = ch(0))
+ protected[this] def rebuildWithSymbols(gen: ConstArray[TermSymbol]) = copy(generator = gen(0))
+}
+
/** A .take call. */
final case class Take(from: Node, count: Node) extends SimpleFilteredQuery with BinaryNode {
type Self = Take
@@ -81,6 +81,13 @@ class MergeToComprehensions extends Phase {
logger.debug("Merged SortBy into Comprehension:", c2)
(c2, replacements1)
+ case ForUpdate(s1, f1) =>
+ val (c1, replacements1) = mergeSortBy(f1, true)
+ logger.debug("Merging ForUpdate into Comprehension:", Ellipsis(n, List(0)))
+ val c2 = c1.copy(forUpdate = true) :@ c1.nodeType
+ logger.debug("Merged ForUpdate into Comprehension:", c2)
+ (c2, replacements1)
+
case Distinct(s1, f1, o1) =>
val (c1, replacements1) = mergeSortBy(f1, true)
val (c1a, replacements1a) =
@@ -20,7 +20,7 @@ class RemoveFieldNames(val alwaysKeepSubqueryNames: Boolean = false) extends Pha
val refTSyms = n.collect[TypeSymbol] {
case Select(_ :@ NominalType(s, _), _) => s
case Union(_, _ :@ CollectionType(_, NominalType(s, _)), _) => s
- case Comprehension(_, _ :@ CollectionType(_, NominalType(s, _)), _, _, _, _, _, _, _, _) if alwaysKeepSubqueryNames => s
+ case Comprehension(_, _ :@ CollectionType(_, NominalType(s, _)), _, _, _, _, _, _, _, _, _) if alwaysKeepSubqueryNames => s
}.toSet
val allTSyms = n.collect[TypeSymbol] { case p: Pure => p.identity }.toSet
val unrefTSyms = allTSyms -- refTSyms
@@ -36,7 +36,7 @@ class RewriteBooleans extends Phase {
case Apply(sym, ch) :@ tpe if isBooleanLike(tpe) =>
toFake(Apply(sym, ch)(n.nodeType).infer())
// Where clauses, join conditions and case clauses need real boolean predicates
- case n @ Comprehension(_, _, _, where, _, _, having, _, _, _) =>
+ case n @ Comprehension(_, _, _, where, _, _, having, _, _, _, _) =>
n.copy(where = where.map(toReal), having = having.map(toReal)) :@ n.nodeType
case n @ Join(_, _, _, _, _, on) =>
n.copy(on = toReal(on)) :@ n.nodeType
@@ -14,9 +14,9 @@ class SpecializeParameters extends Phase {
state.map(ClientSideOp.mapServerSide(_, keepType = true)(transformServerSide))
def transformServerSide(n: Node): Node = {
- val cs = n.collect { case c @ Comprehension(_, _, _, _, _, _, _, _, Some(_: QueryParameter), _) => c }
+ val cs = n.collect { case c @ Comprehension(_, _, _, _, _, _, _, _, Some(_: QueryParameter), _, _) => c }
logger.debug("Affected fetch clauses in: "+cs.mkString(", "))
- cs.foldLeft(n) { case (n, c @ Comprehension(_, _, _, _, _, _, _, _, Some(fetch: QueryParameter), _)) =>
+ cs.foldLeft(n) { case (n, c @ Comprehension(_, _, _, _, _, _, _, _, Some(fetch: QueryParameter), _, _)) =>
val compiledFetchParam = QueryParameter(fetch.extractor, ScalaBaseType.longType)
val guarded = n.replace({ case c2: Comprehension if c2 == c => c2.copy(fetch = Some(LiteralNode(0L))) }, keepType = true)
val fallback = n.replace({ case c2: Comprehension if c2 == c => c2.copy(fetch = Some(compiledFetchParam)) }, keepType = true)
@@ -108,6 +108,13 @@ trait DB2Profile extends JdbcProfile {
expr(n)
if(o.direction.desc) b += " desc"
}
+
+ override protected def buildForUpdateClause(forUpdate: Boolean) = {
+ super.buildForUpdateClause(forUpdate)
+ if(forUpdate) {
+ b" with RS "
+ }
+ }
}
class TableDDLBuilder(table: Table[_]) extends super.TableDDLBuilder(table) {
@@ -125,6 +125,13 @@ trait DerbyProfile extends JdbcProfile {
override protected val supportsLiteralGroupBy = true
override protected val quotedJdbcFns = Some(Vector(Library.User))
+ override protected def buildForUpdateClause(forUpdate: Boolean) = {
+ super.buildForUpdateClause(forUpdate)
+ if (forUpdate) {
+ b" with RS "
+ }
+ }
+
override def expr(c: Node, skipParens: Boolean = false): Unit = c match {
case Library.Cast(ch @ _*) =>
/* Work around DERBY-2072 by casting numeric values first to CHAR and
@@ -29,10 +29,12 @@ object JdbcCapabilities {
val distinguishesIntTypes = Capability("jdbc.distinguishesIntTypes")
/** Has a datatype directly corresponding to Scala Byte */
val supportsByte = Capability("jdbc.supportsByte")
+ /** Supports FOR UPDATE row level locking */
+ val forUpdate = Capability("jdbc.forUpdate")
/** Supports all JdbcProfile features which do not have separate capability values */
val other = Capability("jdbc.other")
/** All JDBC capabilities */
- val all = Set(other, createModel, forceInsert, insertOrUpdate, mutable, returnInsertKey, defaultValueMetaData, booleanMetaData, nullableNoDefault, distinguishesIntTypes, supportsByte, returnInsertOther)
+ val all = Set(other, createModel, forceInsert, insertOrUpdate, mutable, returnInsertKey, defaultValueMetaData, booleanMetaData, nullableNoDefault, distinguishesIntTypes, supportsByte, returnInsertOther, forUpdate)
}
@@ -148,6 +148,7 @@ trait JdbcStatementBuilderComponent { self: JdbcProfile =>
buildHavingClause(c.having)
buildOrderByClause(c.orderBy)
if(!limit0) buildFetchOffsetClause(c.fetch, c.offset)
+ buildForUpdateClause(c.forUpdate)
currentUniqueFrom = oldUniqueFrom
}
@@ -248,6 +249,12 @@ trait JdbcStatementBuilderComponent { self: JdbcProfile =>
}
}
+ protected def buildForUpdateClause(forUpdate: Boolean) = building(OtherPart) {
+ if(forUpdate) {
+ b"\nfor update "
+ }
+ }
+
protected def buildSelectPart(n: Node): Unit = n match {
case c: Comprehension =>
b"\["
@@ -437,7 +444,7 @@ trait JdbcStatementBuilderComponent { self: JdbcProfile =>
def buildUpdate: SQLBuilder.Result = {
val (gen, from, where, select) = tree match {
- case Comprehension(sym, from: TableNode, Pure(select, _), where, None, _, None, None, None, None) => select match {
+ case Comprehension(sym, from: TableNode, Pure(select, _), where, None, _, None, None, None, None, false) => select match {
case f @ Select(Ref(struct), _) if struct == sym => (sym, from, where, ConstArray(f.field))
case ProductNode(ch) if ch.forall{ case Select(Ref(struct), _) if struct == sym => true; case _ => false} =>
(sym, from, where, ch.map{ case Select(Ref(_), field) => field })
@@ -461,9 +468,9 @@ trait JdbcStatementBuilderComponent { self: JdbcProfile =>
def fail(msg: String) =
throw new SlickException("Invalid query for DELETE statement: " + msg)
val (gen, from, where) = tree match {
- case Comprehension(sym, from, Pure(select, _), where, _, _, None, distinct, fetch, offset) =>
- if(fetch.isDefined || offset.isDefined || distinct.isDefined)
- fail(".take, .drop and .distinct are not supported for deleting")
+ case Comprehension(sym, from, Pure(select, _), where, _, _, None, distinct, fetch, offset, forUpdate) =>
+ if(fetch.isDefined || offset.isDefined || distinct.isDefined || forUpdate)
+ fail(".take, .drop .forUpdate and .distinct are not supported for deleting")
from match {
case from: TableNode => (sym, from, where)
case from => fail("A single source table is required, found: "+from)
@@ -153,6 +153,19 @@ trait SQLServerProfile extends JdbcProfile {
if(o.direction.desc) b" desc"
}
+ override protected def buildFromClause(from: Seq[(TermSymbol, Node)]) = {
+ super.buildFromClause(from)
+ tree match {
+ // SQL Server "select for update" syntax
+ case c: Comprehension => if(c.forUpdate) b" with (updlock,rowlock) "
+ case _ =>
+ }
+ }
+
+ override protected def buildForUpdateClause(forUpdate: Boolean) = {
+ // SQLSever doesn't have "select for update" syntax, so use with (updlock,rowlock) in from clause
+ }
+
override def expr(n: Node, skipParens: Boolean = false): Unit = n match {
// Cast bind variables of type TIME to TIME (otherwise they're treated as TIMESTAMP)
case c @ LiteralNode(v) if c.volatileHint && jdbcTypeFor(c.nodeType) == columnTypes.timeJdbcType =>
@@ -91,6 +91,7 @@ trait SQLiteProfile extends JdbcProfile {
- JdbcCapabilities.booleanMetaData
- JdbcCapabilities.supportsByte
- JdbcCapabilities.distinguishesIntTypes
+ - JdbcCapabilities.forUpdate
)
class ModelBuilder(mTables: Seq[MTable], ignoreInvalidDefaults: Boolean)(implicit ec: ExecutionContext) extends JdbcModelBuilder(mTables, ignoreInvalidDefaults) {
@@ -155,6 +155,11 @@ sealed abstract class Query[+E, U, C[_]] extends QueryBase[C[U]] { self =>
new WrappingQuery[(G, Query[P, U, Seq]), (T, Query[P, U, Seq]), C](group, key.zip(value))
}
+ /** Specify part of a select statement for update and marked for row level locking */
+ def forUpdate: Query[E, U, C] = {
+ val generator = new AnonSymbol
+ new WrappingQuery[E, U, C](ForUpdate(generator, toNode), shaped)
+ }
def encodeRef(path: Node): Query[E, U, C] = new Query[E, U, C] {
val shaped = self.shaped.encodeRef(path)
def toNode = path

0 comments on commit 919609f

Please sign in to comment.