Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

fail fast when the connection pool is under-populated, corresponding …

…config, unit and integration tests
  • Loading branch information...
commit f9f42423ea988b56fcf6ac9fcc93da3704e6c6d4 1 parent 83ac979
yswu authored
View
23 querulous-core/src/main/scala/com/twitter/querulous/config/Database.scala
@@ -4,14 +4,21 @@ import com.twitter.querulous._
import com.twitter.util.Duration
import com.twitter.conversions.time._
import database._
+import util.Random
+trait FailFastPolicyConfig {
+ def highWaterMark: Double
+ def lowWaterMark: Double
+ def openTimeout: Duration
+ def rng: Option[Random]
+}
trait PoolingDatabase {
def apply(): DatabaseFactory
}
-trait ServiceNameTagged {
- def apply(serviceName: Option[String]): DatabaseFactory
+trait ServiceNameAndFailFastPolicy {
+ def apply(serviceName: Option[String], ffp: Option[FailFastPolicyConfig]): DatabaseFactory
}
class ApachePoolingDatabase extends PoolingDatabase {
@@ -28,19 +35,19 @@ class ApachePoolingDatabase extends PoolingDatabase {
}
}
-class ThrottledPoolingDatabase extends PoolingDatabase with ServiceNameTagged {
+class ThrottledPoolingDatabase extends PoolingDatabase with ServiceNameAndFailFastPolicy {
var size: Int = 10
var openTimeout: Duration = 50.millis
var repopulateInterval: Duration = 500.millis
var idleTimeout: Duration = 1.minute
def apply() = {
- apply(None)
+ apply(None, None)
}
- def apply(serviceName: Option[String]) = {
+ def apply(serviceName: Option[String], ffp: Option[FailFastPolicyConfig]) = {
new ThrottledPoolingDatabaseFactory(
- serviceName, size, openTimeout, idleTimeout, repopulateInterval, Map.empty)
+ serviceName, size, openTimeout, idleTimeout, repopulateInterval, Map.empty, ffp)
}
}
@@ -73,6 +80,8 @@ class Database {
var memoize: Boolean = true
var serviceName: Option[String] = None
def serviceName_=(s: String) { serviceName = Some(s) }
+ var failFastPolicyConfig: Option[FailFastPolicyConfig] = None
+ def failFastPolicyConfig_=(ffp: FailFastPolicyConfig) { failFastPolicyConfig = Some(ffp) }
def apply(stats: StatsCollector): DatabaseFactory = apply(stats, None)
@@ -80,7 +89,7 @@ class Database {
def apply(stats: StatsCollector, statsFactory: Option[DatabaseFactory => DatabaseFactory]): DatabaseFactory = {
var factory = pool.map{ _ match {
- case p: ServiceNameTagged => p(serviceName)
+ case p: ServiceNameAndFailFastPolicy => p(serviceName, failFastPolicyConfig)
case p: PoolingDatabase => p()
}}.getOrElse(new SingleConnectionDatabaseFactory)
View
114 querulous-core/src/main/scala/com/twitter/querulous/database/ThrottledPoolingDatabase.scala
@@ -4,14 +4,84 @@ import java.util.concurrent.{TimeUnit, LinkedBlockingQueue}
import java.sql.{SQLException, DriverManager, Connection}
import org.apache.commons.dbcp.{PoolingDataSource, DelegatingConnection}
import org.apache.commons.pool.{PoolableObjectFactory, ObjectPool}
+import com.twitter.querulous.config.FailFastPolicyConfig
import com.twitter.util.Duration
import com.twitter.util.Time
import scala.annotation.tailrec
-import java.lang.Thread
import java.util.concurrent.atomic.AtomicInteger
+import java.security.InvalidParameterException
+import util.Random
+import java.lang.{UnsupportedOperationException, Thread}
+
+class FailToAcquireConnectionException extends SQLException
+class PoolTimeoutException extends FailToAcquireConnectionException
+class PoolFailFastException extends FailToAcquireConnectionException
+class PoolEmptyException extends PoolFailFastException
+
+/**
+ * determine whether to fail fast when trying to check out a connection from pool
+ */
+trait FailFastPolicy {
+ /**
+ * This method throws PoolFailFastException when it decides to fail fast given the current state
+ * of the underlying database, or it could throw PoolTimeoutException when failing to acquire
+ * a connection within specified time frame
+ *
+ * @param db the database from which we are going to open connections
+ */
+ @throws(classOf[FailToAcquireConnectionException])
+ def failFast(pool: ObjectPool)(f: Duration => Connection): Connection
+}
-class PoolTimeoutException extends SQLException
-class PoolEmptyException extends SQLException
+/**
+ * This policy behaves in the way specified as follows:
+ * When the number of connections in the pool is below the highWaterMark, start to use the timeout
+ * passed in when waiting for a connection; when it is below the lowWaterMark, start to fail
+ * immediately proportional to the number of connections available in the pool, with 100% failure
+ * rate when the pool is empty
+ */
+class FailFastBasedOnNumConnsPolicy(val highWaterMark: Double, val lowWaterMark: Double,
+ val openTimeout: Duration, val rng: Random) extends FailFastPolicy {
+ if (highWaterMark < lowWaterMark || highWaterMark > 1 || lowWaterMark < 0) {
+ throw new InvalidParameterException("invalid water mark")
+ }
+
+ @throws(classOf[FailToAcquireConnectionException])
+ def failFast(pool: ObjectPool)(f: Duration => Connection) = {
+ pool match {
+ case p: ThrottledPool => {
+ val numConn = p.getTotal()
+ if (numConn == 0) {
+ throw new PoolEmptyException
+ } else if (numConn < p.size * lowWaterMark) {
+ if(numConn < rng.nextDouble() * p.size * lowWaterMark) {
+ throw new PoolFailFastException
+ } else {
+ // should still try to do aggressive timeout at least
+ f(openTimeout)
+ }
+ } else if (numConn < p.size * highWaterMark) {
+ f(openTimeout)
+ } else {
+ f(p.timeout)
+ }
+ }
+ case _ => throw new UnsupportedOperationException("Only support ThrottledPoolingDatabase")
+ }
+ }
+}
+
+object FailFastBasedOnNumConnsPolicy {
+ def apply(openTimeout: Duration): FailFastBasedOnNumConnsPolicy = {
+ apply(0, 0, openTimeout, Some(new Random(System.currentTimeMillis())))
+ }
+
+ def apply(highWaterMark: Double, lowWaterMark: Double, openTimeout: Duration,
+ rng: Option[Random]): FailFastBasedOnNumConnsPolicy = {
+ new FailFastBasedOnNumConnsPolicy(highWaterMark, lowWaterMark, openTimeout,
+ rng.getOrElse(new Random(System.currentTimeMillis())))
+ }
+}
class PooledConnection(c: Connection, p: ObjectPool) extends DelegatingConnection(c) {
private var pool: Option[ObjectPool] = Some(p)
@@ -43,12 +113,16 @@ class PooledConnection(c: Connection, p: ObjectPool) extends DelegatingConnectio
}
}
-class ThrottledPool(factory: () => Connection, val size: Int, timeout: Duration,
- idleTimeout: Duration) extends ObjectPool {
+case class ThrottledPool(factory: () => Connection, val size: Int, timeout: Duration,
+ idleTimeout: Duration, failFastPolicy: FailFastPolicy) extends ObjectPool {
private val pool = new LinkedBlockingQueue[(Connection, Time)]()
private val currentSize = new AtomicInteger(0)
private val numWaiters = new AtomicInteger(0)
+ def this(factory: () => Connection, size: Int, timeout: Duration, idleTimeout: Duration) = {
+ this(factory, size, timeout, idleTimeout, FailFastBasedOnNumConnsPolicy(timeout))
+ }
+
for (i <- (0.until(size))) addObject()
def addObject() {
@@ -69,17 +143,17 @@ class ThrottledPool(factory: () => Connection, val size: Int, timeout: Duration,
final def borrowObject(): Connection = {
numWaiters.incrementAndGet()
try {
- borrowObjectInternal()
+ failFastPolicy.failFast(this)(borrowObjectInternal)
} finally {
numWaiters.decrementAndGet()
}
}
- @tailrec private def borrowObjectInternal(): Connection = {
+ @tailrec private def borrowObjectInternal(openTimeout: Duration): Connection = {
// short circuit if the pool is empty
if (getTotal() == 0) throw new PoolEmptyException
- val pair = pool.poll(timeout.inMillis, TimeUnit.MILLISECONDS)
+ val pair = pool.poll(openTimeout.inMillis, TimeUnit.MILLISECONDS)
if (pair == null) throw new PoolTimeoutException
val (connection, lastUse) = pair
@@ -88,7 +162,7 @@ class ThrottledPool(factory: () => Connection, val size: Int, timeout: Duration,
try { connection.close() } catch { case _: SQLException => }
// note: dbcp handles object invalidation here.
addObjectIfEmpty()
- borrowObjectInternal()
+ borrowObjectInternal(openTimeout)
} else {
connection
}
@@ -170,11 +244,19 @@ class ThrottledPoolingDatabaseFactory(
openTimeout: Duration,
idleTimeout: Duration,
repopulateInterval: Duration,
- defaultUrlOptions: Map[String, String]) extends DatabaseFactory {
+ defaultUrlOptions: Map[String, String],
+ failFastPolicyConfig: Option[FailFastPolicyConfig]) extends DatabaseFactory {
+
+ // the default is the one with both highWaterMark and lowWaterMark of 0
+ // in this case, PoolEmptyException will be thrown when the number of connections in the pool
+ // is zero; otherwise, it will behave the same way as if this policy is not applied
+ private val failFastPolicy = failFastPolicyConfig map {pc =>
+ FailFastBasedOnNumConnsPolicy(pc.highWaterMark, pc.lowWaterMark, pc.openTimeout, pc.rng)
+ } getOrElse (FailFastBasedOnNumConnsPolicy(openTimeout))
def this(size: Int, openTimeout: Duration, idleTimeout: Duration, repopulateInterval: Duration,
defaultUrlOptions: Map[String, String]) = {
- this(None, size, openTimeout, idleTimeout, repopulateInterval, defaultUrlOptions)
+ this(None, size, openTimeout, idleTimeout, repopulateInterval, defaultUrlOptions, None)
}
def this(size: Int, openTimeout: Duration, idleTimeout: Duration,
@@ -192,7 +274,7 @@ class ThrottledPoolingDatabaseFactory(
}
new ThrottledPoolingDatabase(serviceName, dbhosts, dbname, username, password, finalUrlOptions,
- size, openTimeout, idleTimeout, repopulateInterval)
+ size, openTimeout, idleTimeout, repopulateInterval, failFastPolicy)
}
}
@@ -206,11 +288,13 @@ class ThrottledPoolingDatabase(
numConnections: Int,
val openTimeout: Duration,
idleTimeout: Duration,
- repopulateInterval: Duration) extends Database {
+ repopulateInterval: Duration,
+ val failFastPolicy: FailFastPolicy) extends Database {
Class.forName("com.mysql.jdbc.Driver")
- private val pool = new ThrottledPool(mkConnection, numConnections, openTimeout, idleTimeout)
+ private[database] val pool = new ThrottledPool(mkConnection, numConnections, openTimeout,
+ idleTimeout, failFastPolicy)
private val poolingDataSource = new PoolingDataSource(pool)
poolingDataSource.setAccessToUnderlyingConnectionAllowed(true)
new PoolWatchdogThread(pool, hosts, repopulateInterval).start()
@@ -226,7 +310,7 @@ class ThrottledPoolingDatabase(
extraUrlOptions: Map[String, String], numConnections: Int, openTimeout: Duration,
idleTimeout: Duration, repopulateInterval: Duration) = {
this(None, hosts, name, username, password, extraUrlOptions, numConnections, openTimeout,
- idleTimeout, repopulateInterval)
+ idleTimeout, repopulateInterval, FailFastBasedOnNumConnsPolicy(openTimeout))
}
def open() = {
View
139 ...ous-core/src/test/scala/com/twitter/querulous/integration/ThrottledPoolingDatabaseWithFakeConnSpec.scala
@@ -1,13 +1,26 @@
package com.twitter.querulous.integration
-import com.twitter.util.Time
import com.twitter.conversions.time._
-import com.twitter.querulous.evaluator.StandardQueryEvaluatorFactory
+import com.twitter.querulous.evaluator.{QueryEvaluator, StandardQueryEvaluatorFactory}
import com.twitter.querulous.ConfiguredSpecification
import com.twitter.querulous.sql.{FakeContext, FakeDriver}
import com.mysql.jdbc.exceptions.jdbc4.CommunicationsException
-import com.twitter.querulous.database.{PoolEmptyException, Database, ThrottledPoolingDatabaseFactory}
import com.twitter.querulous.query.{SqlQueryTimeoutException, TimingOutQueryFactory, SqlQueryFactory}
+import collection.immutable.Vector
+import util.Random
+import com.twitter.util.{Duration, Time}
+import com.twitter.querulous.database.{PoolFailFastException, PoolEmptyException, Database, ThrottledPoolingDatabaseFactory}
+import com.twitter.querulous.config.FailFastPolicyConfig
+
+class MockRandom extends Random {
+ private[this] var currIndex = 0
+ private[this] val vals: IndexedSeq[Double] = Vector(0.75, 0.85, 0.25, 0.25, 0.25)
+ // cycle through the list of values specified above
+ override def nextDouble(): Double = {
+ currIndex += 1
+ vals((currIndex - 1)%vals.size)
+ }
+}
object ThrottledPoolingDatabaseWithFakeConnSpec {
// configure repopulation interval to a minute to avoid conn repopulation when test running
@@ -22,25 +35,58 @@ object ThrottledPoolingDatabaseWithFakeConnSpec {
1.second, 100.milliseconds, Map("connectTimeout" -> "2000"))
val testRepopulatedLongConnTimeoutEvaluatorFactory = new StandardQueryEvaluatorFactory(
testRepopulatedLongConnTimeoutDbFactory, testQueryFactory)
+
+ val failFastPolicyConfig = new FailFastPolicyConfig {
+ def highWaterMark: Double = 0.75
+ def openTimeout: Duration = 1.second
+ def lowWaterMark: Double = 0.5
+ def rng: Option[Random] = Some(new MockRandom())
+ }
+ val testFailFastDatabaseFactory = new ThrottledPoolingDatabaseFactory(Some("test"), 8, 1.second,
+ 60.second, 60.seconds, Map.empty, Some(failFastPolicyConfig))
+ val testFailFastEvaluatorFactory = new StandardQueryEvaluatorFactory(testFailFastDatabaseFactory,
+ testQueryFactory)
+
+ val testDatabaseFactoryWithDefaultFailFastPolicy = new ThrottledPoolingDatabaseFactory(
+ Some("test"), 8, 1.second, 60.second, 60.seconds, Map.empty, None)
+ val testEvaluatorFactoryWithDefaultFailFastPolicy = new StandardQueryEvaluatorFactory(
+ testDatabaseFactoryWithDefaultFailFastPolicy, testQueryFactory)
+
+ def destroyConnection(queryEvaluator: QueryEvaluator, host: String, numConns: Int = 1) {
+ FakeContext.markServerDown(host)
+ try {
+ for(i <- 0 until numConns) {
+ try {
+ queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) }
+ assert(false)
+ } catch {
+ case e: CommunicationsException => // expected
+ case t: Throwable => throw t
+ }
+ }
+ } finally {
+ FakeContext.markServerUp(host)
+ }
+ }
}
class ThrottledPoolingDatabaseWithFakeConnSpec extends ConfiguredSpecification {
import ThrottledPoolingDatabaseWithFakeConnSpec._
+ val host = config.hostnames.mkString(",") + "/" + config.database
+ FakeContext.setQueryResult(host, "SELECT 1 FROM DUAL", Array(Array[java.lang.Object](1.asInstanceOf[AnyRef])))
+ FakeContext.setQueryResult(host, "SELECT 2 FROM DUAL", Array(Array[java.lang.Object](2.asInstanceOf[AnyRef])))
+
doBeforeSpec { Database.driverName = FakeDriver.DRIVER_NAME }
"ThrottledJdbcPoolSpec" should {
- val host = config.hostnames.mkString(",") + "/" + config.database
val queryEvaluator = testEvaluatorFactory(config)
-
- FakeContext.setQueryResult(host, "SELECT 1 FROM DUAL", Array(Array[java.lang.Object](1.asInstanceOf[AnyRef])))
- FakeContext.setQueryResult(host, "SELECT 2 FROM DUAL", Array(Array[java.lang.Object](2.asInstanceOf[AnyRef])))
- "execute some queries" >> {
+ "execute some queries" in {
queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1)
queryEvaluator.select("SELECT 2 FROM DUAL") { r => r.getInt(1) } mustEqual List(2)
}
- "failfast after a host is down" >> {
+ "failfast after a host is down" in {
queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1)
FakeContext.markServerDown(host)
try {
@@ -53,7 +99,7 @@ class ThrottledPoolingDatabaseWithFakeConnSpec extends ConfiguredSpecification {
}
}
- "failfast after connections are closed due to query timeout" >> {
+ "failfast after connections are closed due to query timeout" in {
queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1)
FakeContext.setTimeTakenToExecQuery(host, 1.second)
try {
@@ -67,7 +113,7 @@ class ThrottledPoolingDatabaseWithFakeConnSpec extends ConfiguredSpecification {
}
}
- "repopulate the pool every repopulation interval" >> {
+ "repopulate the pool every repopulation interval" in {
val queryEvaluator = testRepopulatedEvaluatorFactory(config)
queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1)
@@ -84,7 +130,7 @@ class ThrottledPoolingDatabaseWithFakeConnSpec extends ConfiguredSpecification {
}
}
- "repopulate the pool even if it takes longer to establish a connection than repopulation interval" >> {
+ "repopulate the pool even if it takes longer to establish a connection than repopulation interval" in {
val queryEvaluator = testRepopulatedLongConnTimeoutEvaluatorFactory(config)
queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1)
@@ -104,5 +150,74 @@ class ThrottledPoolingDatabaseWithFakeConnSpec extends ConfiguredSpecification {
}
}
+ "ThrottledJdbcPoolWithFailFastPolicy" should {
+ val queryEvaluator = testFailFastEvaluatorFactory(config)
+
+ "execute query normally as the pool is full or above highWaterMark" in {
+ // number of connections in the pool 8
+ queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1)
+ destroyConnection(queryEvaluator, host)
+ // number of connections in the pool 7
+ queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1)
+ }
+
+ // here we should see more aggressive timeout applied on acquiring connection, but currently
+ // it is hard to test that. Nonetheless, this is covered by the unit test.
+ "execute query normally until the pool reaches the lowWaterMark" in {
+ destroyConnection(queryEvaluator, host, 2)
+ // number of connections in the pool 6 = 8 * 0.75, at highWaterMark
+ queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1)
+ destroyConnection(queryEvaluator, host)
+ queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1)
+ destroyConnection(queryEvaluator, host)
+ // number of connection is 4 = 8 * 0.5, at lowWaterMark
+ queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1)
+ }
+
+ "fail fast when the pool is under the lowWaterMark and throw PoolEmptyException when the pool is empty" in {
+ destroyConnection(queryEvaluator, host, 5)
+ // number of connection is at 3, but still fine because the first double out of random number
+ // generator is 0.75, 3 = 8 * 0.5 * 0.75
+ queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1)
+ // 3 < 8 * 0.5 * 0.5
+ queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } must throwA[PoolFailFastException]
+ destroyConnection(queryEvaluator, host, 3)
+ queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } must throwA[PoolEmptyException]
+ }
+ }
+
+ "ThrottledJdbcPoolWithDefaultFailFastPolicy" should {
+ "execute query normally as the pool until there is no connection in the pool" in {
+ val queryEvaluator = testEvaluatorFactoryWithDefaultFailFastPolicy(config)
+
+ // number of connections in the pool 8
+ queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1)
+ destroyConnection(queryEvaluator, host)
+ // number of connections in the pool 7
+ queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1)
+ destroyConnection(queryEvaluator, host)
+ // number of connections in the pool 6
+ queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1)
+ destroyConnection(queryEvaluator, host)
+ // number of connections in the pool 5
+ queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1)
+ destroyConnection(queryEvaluator, host)
+ // number of connections in the pool 4
+ queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1)
+ destroyConnection(queryEvaluator, host)
+ // number of connections in the pool 3
+ queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1)
+ destroyConnection(queryEvaluator, host)
+ // number of connections in the pool 2
+ queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1)
+ destroyConnection(queryEvaluator, host)
+ // number of connections in the pool 1
+ queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1)
+ destroyConnection(queryEvaluator, host)
+ // no connection left in the pool
+ queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } must throwA[PoolEmptyException]
+ }
+ }
+
doAfterSpec { Database.driverName = "jdbc:mysql" }
}
View
33 querulous-core/src/test/scala/com/twitter/querulous/unit/DefaultFailFastPolicySpec.scala
@@ -0,0 +1,33 @@
+package com.twitter.querulous.unit
+
+import com.twitter.conversions.time._
+import org.specs.Specification
+import org.specs.mock.{ClassMocker, JMocker}
+import java.sql.Connection
+import com.twitter.querulous.database.{FailFastBasedOnNumConnsPolicy, PoolEmptyException, ThrottledPool}
+
+class DefaultFailFastPolicySpec extends Specification with JMocker with ClassMocker {
+ val conn = mock[Connection]
+ val connFactory = mock[TestConnectionFactory]
+ val ffp = FailFastBasedOnNumConnsPolicy(0, 0, 1.second, None)
+ val pool = mock[ThrottledPool]
+
+ "DefaultFailFastPolicySpec" should {
+ "throw PoolEmptyException when the pool does not have any connection" in {
+ expect {
+ one(pool).getTotal() willReturn 0
+ }
+ ffp.failFast(pool)(connFactory.getConnection(_)) must throwA[PoolEmptyException]
+ }
+
+ "get a connection with the normal timeout setting even if the pool only has one connection" in {
+ expect {
+ one(pool).getTotal() willReturn 1
+ 2.of(pool).size willReturn 8
+ one(pool).timeout willReturn 2.seconds
+ one(connFactory).getConnection(2.seconds) willReturn conn
+ }
+ ffp.failFast(pool)(connFactory.getConnection(_)) mustEqual conn
+ }
+ }
+}
View
101 querulous-core/src/test/scala/com/twitter/querulous/unit/FailFastBasedOnNumConnsPolicySpec.scala
@@ -0,0 +1,101 @@
+package com.twitter.querulous.unit
+
+import com.twitter.conversions.time._
+import org.specs.Specification
+import org.specs.mock.{ClassMocker, JMocker}
+import com.twitter.util.Duration
+import java.sql.Connection
+import java.security.InvalidParameterException
+import com.twitter.querulous.database.{FailFastBasedOnNumConnsPolicy, PoolFailFastException, PoolEmptyException, ThrottledPool}
+
+trait TestConnectionFactory {
+ def getConnection(timeout: Duration): Connection
+}
+
+class FailFastBasedOnNumConnsPolicySpec extends Specification with JMocker with ClassMocker {
+ val conn = mock[Connection]
+ val rng = mock[scala.util.Random]
+ val connFactory = mock[TestConnectionFactory]
+ val ffp = FailFastBasedOnNumConnsPolicy(0.75, 0.5, 1.second, Some(rng))
+ val pool = mock[ThrottledPool]
+
+ "FailFastBasedOnNumConnsPolicySpec" should {
+ "throw InvalidParameterException when highWaterMark is lower than lowWaterMark" in {
+ FailFastBasedOnNumConnsPolicy(0.5, 0.75, 1.second, Some(rng)) must throwA[InvalidParameterException]
+ }
+
+ "throw InvalidParameterException when lowWaterMark is negative or 0" in {
+ FailFastBasedOnNumConnsPolicy(0.5, -0.1, 1.second, Some(rng)) must throwA[InvalidParameterException]
+ }
+
+ "throw InvalidParameterException when hightWaterMark is greater than 1" in {
+ FailFastBasedOnNumConnsPolicy(1.1, 0.75, 1.second, Some(rng)) must throwA[InvalidParameterException]
+ }
+
+ "throw PoolEmptyException when the pool does not have any connection" in {
+ expect {
+ one(pool).getTotal() willReturn 0
+ }
+ ffp.failFast(pool)(connFactory.getConnection(_)) must throwA[PoolEmptyException]
+ }
+
+ "throw PoolFailFastException when pool below lowWaterMark and unlucky" in {
+ expect {
+ one(pool).getTotal() willReturn 2
+ one(pool).size willReturn 8
+ one(rng).nextDouble() willReturn 0.75
+ one(pool).size willReturn 8
+ }
+ ffp.failFast(pool)(connFactory.getConnection(_)) must throwA[PoolFailFastException]
+ }
+
+ "get a connection with the more aggressive timeout setting when pool below lowWaterMark and but lucky enough" in {
+ expect {
+ one(pool).getTotal() willReturn 2
+ one(pool).size willReturn 8
+ one(rng).nextDouble() willReturn 0.25
+ one(pool).size willReturn 8
+ one(connFactory).getConnection(1.second) willReturn conn
+ }
+ ffp.failFast(pool)(connFactory.getConnection(_)) mustEqual conn
+ }
+
+ "get a connection with the more aggressive timeout setting when pool below highWaterMark but at lowWaterMark" in {
+ expect {
+ one(pool).getTotal() willReturn 4
+ 2.of(pool).size willReturn 8
+ one(connFactory).getConnection(1.second) willReturn conn
+ }
+ ffp.failFast(pool)(connFactory.getConnection(_)) mustEqual conn
+ }
+
+ "get a connection with the more aggressive timeout setting when pool below highWaterMark but above lowWaterMark" in {
+ expect {
+ one(pool).getTotal() willReturn 5
+ 2.of(pool).size willReturn 8
+ one(connFactory).getConnection(1.second) willReturn conn
+ }
+ ffp.failFast(pool)(connFactory.getConnection(_)) mustEqual conn
+ }
+
+ "get a connection with the normal timeout setting when pool at highWaterMark" in {
+ expect {
+ one(pool).getTotal() willReturn 6
+ 2.of(pool).size willReturn 8
+ one(pool).timeout willReturn 2.seconds
+ one(connFactory).getConnection(2.seconds) willReturn conn
+ }
+ ffp.failFast(pool)(connFactory.getConnection(_)) mustEqual conn
+ }
+
+ "get a connection with the normal timeout setting when pool above highWaterMark" in {
+ expect {
+ one(pool).getTotal() willReturn 7
+ 2.of(pool).size willReturn 8
+ one(pool).timeout willReturn 2.seconds
+ one(connFactory).getConnection(2.seconds) willReturn conn
+ }
+ ffp.failFast(pool)(connFactory.getConnection(_)) mustEqual conn
+ }
+ }
+}
Please sign in to comment.
Something went wrong with that request. Please try again.