protects acquireSession call in a try / finally (#1877)

* protects acquireSession call in a try / finally

* removed some minor warnings

* added ctx.sync = 0

* add test that proves the thread is being freed
renatocaval authored and hvesalai committed Mar 23, 2018
1 parent fc5d3ed commit d713b9fff117e4fc962cc463bdd1cbfdc51c73f8
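The bullets above boil down to one hazard: the streaming task reserves the executor's only connection slot before it runs, and if acquireSession throws, nothing ever hands the slot back, so the managed queue stays paused and no later task can run. The following is a minimal, self-contained sketch of that idea only; ConnectionSlots and its method names are invented for illustration and are not Slick's API.

import scala.util.control.NonFatal

// Toy illustration only: invented names, not Slick's real classes.
object ReleaseInFinallySketch extends App {

  // A slot that is reserved before a task runs and must be handed back afterwards,
  // otherwise no further task can use it (analogous to the paused ManagedArrayBlockingQueue).
  final class ConnectionSlots(max: Int) {
    private var inUse = 0
    def reserve(): Unit  = synchronized { inUse += 1 }
    def giveBack(): Unit = synchronized { inUse -= 1 }
    def paused: Boolean  = synchronized { inUse >= max }
  }

  val slots = new ConnectionSlots(max = 1)

  // Simulates the broken data source used in the test below.
  def acquireSession(): Unit = throw new RuntimeException("DB is not available!")

  // Before the fix: the slot is only given back on the happy path.
  def runUnprotected(): Unit = {
    slots.reserve()
    acquireSession()      // throws, so giveBack() below is never reached
    slots.giveBack()
  }

  // After the fix: the finally hands the slot back even when acquireSession() throws.
  def runProtected(): Unit = {
    slots.reserve()
    try acquireSession() finally slots.giveBack()
  }

  try runUnprotected() catch { case NonFatal(_) => () }
  println(s"unprotected: paused = ${slots.paused}")   // true: the single slot is stuck

  slots.giveBack()                                    // reset the toy state
  try runProtected() catch { case NonFatal(_) => () }
  println(s"protected:   paused = ${slots.paused}")   // false: the slot was freed despite the failure
}

The real change, in the BasicBackend hunk further below, does the same thing inside PrioritizedRunnable.run: the release of the connection slot and the reset of ctx.sync live in a finally, so the AsyncExecutor's queue is resumed even when the session cannot be acquired.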
@@ -0,0 +1,111 @@
package slick.test.jdbc

import java.sql.Connection
import java.util.UUID

import com.typesafe.config.ConfigFactory
import org.junit.Test
import org.junit.Assert._

import slick.jdbc.H2Profile.api._
import slick.jdbc.{JdbcBackend, JdbcDataSource}
import slick.util.{AsyncExecutor, ClassLoaderUtil}

import scala.concurrent.Await
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.util.Failure

class ManagedQueueTest {

  @Test
  def testBrokenConnectionWithStreamingAction(): Unit = {

    val config =
      ConfigFactory.parseString(
        """
          |dataSource {
          |  profile = "slick.jdbc.H2Profile$"
          |  db {
          |    connectionPool = disabled
          |    dataSourceClass = "slick.jdbc.DriverDataSource"
          |    properties = {
          |      driver = "org.h2.Driver"
          |      url = "jdbc:h2:mem:test;DB_CLOSE_DELAY=-1"
          |    }
          |  }
          |}
          |""".stripMargin)

    val dataSource = new JdbcDataSourceWrap(
      JdbcDataSource.forConfig(
        config.getConfig("dataSource.db"),
        /* driver = */ null,
        "test",
        ClassLoaderUtil.defaultClassLoader)
    )

    // only one thread and one connection available
    val asyncExecutor = AsyncExecutor("test", 1, 1, 1, 1)

    class T(tag: Tag) extends Table[Int](tag, "TableA") {
      def a = column[Int]("a")
      def * = a
    }
    val ts = TableQuery[T]

    val db = JdbcBackend.Database.forSource(dataSource, asyncExecutor)

    try {
      val values = Seq(2, 3, 1, 5, 4)

      // init schema and insert some data
      val initAction =
        for {
          _ <- ts.schema.create
          _ <- ts ++= values
        } yield ()

      Await.ready(db.run(initAction.transactionally), 5.seconds)

      // run stream in fail mode
      // before the fix for (https://github.com/slick/slick/issues/1875)
      // this would consume the single thread we have in AsyncExecutor
      dataSource.failMode()
      val streamResult = db.stream(ts.result).foreach(println)
      Await
        .ready(streamResult, 3.seconds)
        .onComplete {
          case Failure(ex) =>
            assertEquals("DB is not available!", ex.getMessage)
          case _ => fail("This was expected to fail")
        }

      // before the fix for (https://github.com/slick/slick/issues/1875)
      // this would have hung forever
      // the fix brings the managed queue back to not-paused state
      // which allows this next call to succeed
      dataSource.successMode()
      val seq = Await.result(db.run(ts.result), 3.seconds)
      assertEquals(values, seq)

    } finally {
      db.close()
    }
  }

  /** JdbcDataSource wrapper that lets us simulate connection failures. */
  class JdbcDataSourceWrap(underlying: JdbcDataSource) extends JdbcDataSource {

    private var _failMode = false

    def failMode() = _failMode = true
    def successMode() = _failMode = false

    override def createConnection(): Connection = {
      if (_failMode) throw new RuntimeException("DB is not available!")
      else underlying.createConnection()
    }

    override def close(): Unit = underlying.close()

    override val maxConnections: Option[Int] = underlying.maxConnections
  }
}
@@ -296,68 +296,86 @@ trait BasicBackend { self =>

The hunk rewrites scheduleSynchronousStreaming so that the whole body of run() sits inside a try whose finally releases the connection slot and resets ctx.sync even when acquireSession(ctx) throws. After the change the method reads:

    }

    /** Stream a part of the results of a `SynchronousDatabaseAction` on this database. */
    protected[BasicBackend] def scheduleSynchronousStreaming(a: SynchronousDatabaseAction[_, _ <: NoStream, This, _ <: Effect], ctx: StreamingContext, continuation: Boolean)(initialState: a.StreamState): Unit =
      try {
        ctx.getEC(synchronousExecutionContext).prepare.execute(new AsyncExecutor.PrioritizedRunnable {
          private[this] def str(l: Long) = if(l != Long.MaxValue) l else if(GlobalConfig.unicodeDump) "\u221E" else "oo"

          def priority = {
            ctx.readSync
            ctx.priority(continuation)
          }

          def run(): Unit =
            try {
              val debug = streamLogger.isDebugEnabled
              var state = initialState
              ctx.readSync
              if(state eq null) acquireSession(ctx)
              var demand = ctx.demandBatch
              var realDemand = if(demand < 0) demand - Long.MinValue else demand
              do {
                try {
                  if(debug)
                    streamLogger.debug((if(state eq null) "Starting initial" else "Restarting ") + " streaming action, realDemand = " + str(realDemand))
                  if(ctx.cancelled) {
                    if(ctx.deferredError ne null) throw ctx.deferredError
                    if(state ne null) { // streaming cancelled before finishing
                      val oldState = state
                      state = null
                      a.cancelStream(ctx, oldState)
                    }
                  } else if(realDemand > 0 || (state eq null)) {
                    val oldState = state
                    state = null
                    state = a.emitStream(ctx, realDemand, oldState)
                  }
                  if(state eq null) { // streaming finished and cleaned up
                    releaseSession(ctx, true)
                    ctx.streamingResultPromise.trySuccess(null)
                  }
                } catch { case NonFatal(ex) =>
                  if(state ne null) try a.cancelStream(ctx, state) catch ignoreFollowOnError
                  releaseSession(ctx, true)
                  throw ex
                } finally {
                  ctx.streamState = state
                  if (!ctx.isPinned && ctx.priority(continuation) != WithConnection) connectionReleased = true
                  ctx.sync = 0
                }
                if(debug) {
                  if(state eq null) streamLogger.debug(s"Sent up to ${str(realDemand)} elements - Stream " + (if(ctx.cancelled) "cancelled" else "completely delivered"))
                  else streamLogger.debug(s"Sent ${str(realDemand)} elements, more available - Performing atomic state transition")
                }
                demand = ctx.delivered(demand)
                realDemand = if(demand < 0) demand - Long.MinValue else demand
              } while ((state ne null) && realDemand > 0)
              if(debug) {
                if(state ne null) streamLogger.debug("Suspending streaming action with continuation (more data available)")
                else streamLogger.debug("Finished streaming action")
              }
            } catch {
              case NonFatal(ex) => ctx.streamingResultPromise.tryFailure(ex)
            } finally {
              if (!ctx.isPinned && ctx.priority(continuation) != WithConnection) connectionReleased = true
              ctx.sync = 0
            }
        })
      } catch { case NonFatal(ex) =>
        streamLogger.warn("Error scheduling synchronous streaming", ex)
        throw ex
      }
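The ctx.readSync / ctx.sync = 0 pair in the method above looks like the usual volatile piggy-backing idiom for publishing the mutable streaming state from one worker task to the next: read the volatile field before touching the state, write it after. Below is a generic sketch of that idiom under that reading; StateHandoff and its members are invented names, not Slick's StreamingContext.

// Generic sketch of publication via a volatile "sync" field (a plain Java-memory-model
// illustration, not Slick's actual code): the plain field is written before the volatile
// write and read after the volatile read, so the happens-before edge publishes it safely.
final class StateHandoff[A] {
  @volatile private var sync: Int = 0    // used only for its memory-ordering effect
  private var state: Option[A] = None    // plain field, published through `sync`

  def publish(a: A): Unit = {
    state = Some(a)   // plain write...
    sync = 0          // ...made visible by the subsequent volatile write
  }

  def read(): Option[A] = {
    val _ = sync      // volatile read first...
    state             // ...then the plain read sees the published value
  }
}

In the hunk above, run() ends each pass by writing ctx.sync = 0 in a finally, and both priority and the next run() begin with ctx.readSync, which matches this write-after / read-before discipline.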
@@ -83,10 +83,10 @@ object AsyncExecutor extends Logging {
          new LinkedBlockingQueue[Runnable]
        case n =>
          // NOTE: The current implementation of ManagedArrayBlockingQueue is flawed. It makes the assumption that all
          // tasks go through the queue (which is responsible for scheduling high-priority tasks first). However, that
          // assumption is wrong since the ThreadPoolExecutor bypasses the queue when it creates new threads. This
          // happens whenever it creates a new thread to run a task, i.e. when minThreads < maxThreads and the number
          // of existing threads is < maxThreads.
          //
          // The only way to prevent problems is to have minThreads == maxThreads when using the
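Given that note, a sizing sketch: keep minThreads equal to maxThreads so the ThreadPoolExecutor never spawns a thread that bypasses the managed queue. The five positional arguments mirror the AsyncExecutor("test", 1, 1, 1, 1) call in the test above; their exact meanings (name, minThreads, maxThreads, queueSize, maxConnections) are assumed from that usage.

import slick.util.AsyncExecutor

object ExecutorSizingSketch {
  // Assumed argument order: name, minThreads, maxThreads, queueSize, maxConnections
  // (mirroring AsyncExecutor("test", 1, 1, 1, 1) in the test above).
  // minThreads == maxThreads, so every task flows through the ManagedArrayBlockingQueue
  // and the bypass described in the NOTE above cannot happen.
  val executor: AsyncExecutor = AsyncExecutor("app-db", 20, 20, 1000, 20)
}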
@@ -55,7 +55,7 @@ class ManagedArrayBlockingQueue[E >: Null <: PrioritizedRunnable](maximumInUse:
      inUseCount -= 1
      if (inUseCount == maximumInUse - 1) {
        logger.debug("resuming")
        paused = false
        if (counts > 0) notEmpty.signalAll()
      }
    }
@@ -96,7 +96,6 @@ class ManagedArrayBlockingQueue[E >: Null <: PrioritizedRunnable](maximumInUse:
        nanos = itemQueueNotFull.awaitNanos(nanos)
      }
      insert(e)
      return true
    }
  }
@@ -190,7 +189,7 @@ class ManagedArrayBlockingQueue[E >: Null <: PrioritizedRunnable](maximumInUse:
      (highPrioItemQueue.iterator.asScala ++ itemQueue.iterator.asScala).toList.toIterator
    }
    new util.Iterator[E] {
      override def hasNext: Boolean = items.hasNext
      override def next: E = items.next
      override def remove(): Unit = throw new UnsupportedOperationException
