diff --git a/querulous-core/src/main/scala/com/twitter/querulous/config/Database.scala b/querulous-core/src/main/scala/com/twitter/querulous/config/Database.scala index ef18933..514092d 100644 --- a/querulous-core/src/main/scala/com/twitter/querulous/config/Database.scala +++ b/querulous-core/src/main/scala/com/twitter/querulous/config/Database.scala @@ -4,14 +4,21 @@ import com.twitter.querulous._ import com.twitter.util.Duration import com.twitter.conversions.time._ import database._ +import util.Random +trait FailFastPolicyConfig { + def highWaterMark: Double + def lowWaterMark: Double + def openTimeout: Duration + def rng: Option[Random] +} trait PoolingDatabase { def apply(): DatabaseFactory } -trait ServiceNameTagged { - def apply(serviceName: Option[String]): DatabaseFactory +trait ServiceNameAndFailFastPolicy { + def apply(serviceName: Option[String], ffp: Option[FailFastPolicyConfig]): DatabaseFactory } class ApachePoolingDatabase extends PoolingDatabase { @@ -28,19 +35,19 @@ class ApachePoolingDatabase extends PoolingDatabase { } } -class ThrottledPoolingDatabase extends PoolingDatabase with ServiceNameTagged { +class ThrottledPoolingDatabase extends PoolingDatabase with ServiceNameAndFailFastPolicy { var size: Int = 10 var openTimeout: Duration = 50.millis var repopulateInterval: Duration = 500.millis var idleTimeout: Duration = 1.minute def apply() = { - apply(None) + apply(None, None) } - def apply(serviceName: Option[String]) = { + def apply(serviceName: Option[String], ffp: Option[FailFastPolicyConfig]) = { new ThrottledPoolingDatabaseFactory( - serviceName, size, openTimeout, idleTimeout, repopulateInterval, Map.empty) + serviceName, size, openTimeout, idleTimeout, repopulateInterval, Map.empty, ffp) } } @@ -73,6 +80,8 @@ class Database { var memoize: Boolean = true var serviceName: Option[String] = None def serviceName_=(s: String) { serviceName = Some(s) } + var failFastPolicyConfig: Option[FailFastPolicyConfig] = None + def failFastPolicyConfig_=(ffp: FailFastPolicyConfig) { failFastPolicyConfig = Some(ffp) } def apply(stats: StatsCollector): DatabaseFactory = apply(stats, None) @@ -80,7 +89,7 @@ class Database { def apply(stats: StatsCollector, statsFactory: Option[DatabaseFactory => DatabaseFactory]): DatabaseFactory = { var factory = pool.map{ _ match { - case p: ServiceNameTagged => p(serviceName) + case p: ServiceNameAndFailFastPolicy => p(serviceName, failFastPolicyConfig) case p: PoolingDatabase => p() }}.getOrElse(new SingleConnectionDatabaseFactory) diff --git a/querulous-core/src/main/scala/com/twitter/querulous/database/ThrottledPoolingDatabase.scala b/querulous-core/src/main/scala/com/twitter/querulous/database/ThrottledPoolingDatabase.scala index c265819..127d333 100644 --- a/querulous-core/src/main/scala/com/twitter/querulous/database/ThrottledPoolingDatabase.scala +++ b/querulous-core/src/main/scala/com/twitter/querulous/database/ThrottledPoolingDatabase.scala @@ -4,14 +4,84 @@ import java.util.concurrent.{TimeUnit, LinkedBlockingQueue} import java.sql.{SQLException, DriverManager, Connection} import org.apache.commons.dbcp.{PoolingDataSource, DelegatingConnection} import org.apache.commons.pool.{PoolableObjectFactory, ObjectPool} +import com.twitter.querulous.config.FailFastPolicyConfig import com.twitter.util.Duration import com.twitter.util.Time import scala.annotation.tailrec -import java.lang.Thread import java.util.concurrent.atomic.AtomicInteger +import java.security.InvalidParameterException +import util.Random +import java.lang.{UnsupportedOperationException, Thread} + +class FailToAcquireConnectionException extends SQLException +class PoolTimeoutException extends FailToAcquireConnectionException +class PoolFailFastException extends FailToAcquireConnectionException +class PoolEmptyException extends PoolFailFastException + +/** + * determine whether to fail fast when trying to check out a connection from pool + */ +trait FailFastPolicy { + /** + * This method throws PoolFailFastException when it decides to fail fast given the current state + * of the underlying database, or it could throw PoolTimeoutException when failing to acquire + * a connection within specified time frame + * + * @param db the database from which we are going to open connections + */ + @throws(classOf[FailToAcquireConnectionException]) + def failFast(pool: ObjectPool)(f: Duration => Connection): Connection +} -class PoolTimeoutException extends SQLException -class PoolEmptyException extends SQLException +/** + * This policy behaves in the way specified as follows: + * When the number of connections in the pool is below the highWaterMark, start to use the timeout + * passed in when waiting for a connection; when it is below the lowWaterMark, start to fail + * immediately proportional to the number of connections available in the pool, with 100% failure + * rate when the pool is empty + */ +class FailFastBasedOnNumConnsPolicy(val highWaterMark: Double, val lowWaterMark: Double, + val openTimeout: Duration, val rng: Random) extends FailFastPolicy { + if (highWaterMark < lowWaterMark || highWaterMark > 1 || lowWaterMark < 0) { + throw new InvalidParameterException("invalid water mark") + } + + @throws(classOf[FailToAcquireConnectionException]) + def failFast(pool: ObjectPool)(f: Duration => Connection) = { + pool match { + case p: ThrottledPool => { + val numConn = p.getTotal() + if (numConn == 0) { + throw new PoolEmptyException + } else if (numConn < p.size * lowWaterMark) { + if(numConn < rng.nextDouble() * p.size * lowWaterMark) { + throw new PoolFailFastException + } else { + // should still try to do aggressive timeout at least + f(openTimeout) + } + } else if (numConn < p.size * highWaterMark) { + f(openTimeout) + } else { + f(p.timeout) + } + } + case _ => throw new UnsupportedOperationException("Only support ThrottledPoolingDatabase") + } + } +} + +object FailFastBasedOnNumConnsPolicy { + def apply(openTimeout: Duration): FailFastBasedOnNumConnsPolicy = { + apply(0, 0, openTimeout, Some(new Random(System.currentTimeMillis()))) + } + + def apply(highWaterMark: Double, lowWaterMark: Double, openTimeout: Duration, + rng: Option[Random]): FailFastBasedOnNumConnsPolicy = { + new FailFastBasedOnNumConnsPolicy(highWaterMark, lowWaterMark, openTimeout, + rng.getOrElse(new Random(System.currentTimeMillis()))) + } +} class PooledConnection(c: Connection, p: ObjectPool) extends DelegatingConnection(c) { private var pool: Option[ObjectPool] = Some(p) @@ -43,12 +113,16 @@ class PooledConnection(c: Connection, p: ObjectPool) extends DelegatingConnectio } } -class ThrottledPool(factory: () => Connection, val size: Int, timeout: Duration, - idleTimeout: Duration) extends ObjectPool { +case class ThrottledPool(factory: () => Connection, val size: Int, timeout: Duration, + idleTimeout: Duration, failFastPolicy: FailFastPolicy) extends ObjectPool { private val pool = new LinkedBlockingQueue[(Connection, Time)]() private val currentSize = new AtomicInteger(0) private val numWaiters = new AtomicInteger(0) + def this(factory: () => Connection, size: Int, timeout: Duration, idleTimeout: Duration) = { + this(factory, size, timeout, idleTimeout, FailFastBasedOnNumConnsPolicy(timeout)) + } + for (i <- (0.until(size))) addObject() def addObject() { @@ -69,17 +143,17 @@ class ThrottledPool(factory: () => Connection, val size: Int, timeout: Duration, final def borrowObject(): Connection = { numWaiters.incrementAndGet() try { - borrowObjectInternal() + failFastPolicy.failFast(this)(borrowObjectInternal) } finally { numWaiters.decrementAndGet() } } - @tailrec private def borrowObjectInternal(): Connection = { + @tailrec private def borrowObjectInternal(openTimeout: Duration): Connection = { // short circuit if the pool is empty if (getTotal() == 0) throw new PoolEmptyException - val pair = pool.poll(timeout.inMillis, TimeUnit.MILLISECONDS) + val pair = pool.poll(openTimeout.inMillis, TimeUnit.MILLISECONDS) if (pair == null) throw new PoolTimeoutException val (connection, lastUse) = pair @@ -88,7 +162,7 @@ class ThrottledPool(factory: () => Connection, val size: Int, timeout: Duration, try { connection.close() } catch { case _: SQLException => } // note: dbcp handles object invalidation here. addObjectIfEmpty() - borrowObjectInternal() + borrowObjectInternal(openTimeout) } else { connection } @@ -170,11 +244,19 @@ class ThrottledPoolingDatabaseFactory( openTimeout: Duration, idleTimeout: Duration, repopulateInterval: Duration, - defaultUrlOptions: Map[String, String]) extends DatabaseFactory { + defaultUrlOptions: Map[String, String], + failFastPolicyConfig: Option[FailFastPolicyConfig]) extends DatabaseFactory { + + // the default is the one with both highWaterMark and lowWaterMark of 0 + // in this case, PoolEmptyException will be thrown when the number of connections in the pool + // is zero; otherwise, it will behave the same way as if this policy is not applied + private val failFastPolicy = failFastPolicyConfig map {pc => + FailFastBasedOnNumConnsPolicy(pc.highWaterMark, pc.lowWaterMark, pc.openTimeout, pc.rng) + } getOrElse (FailFastBasedOnNumConnsPolicy(openTimeout)) def this(size: Int, openTimeout: Duration, idleTimeout: Duration, repopulateInterval: Duration, defaultUrlOptions: Map[String, String]) = { - this(None, size, openTimeout, idleTimeout, repopulateInterval, defaultUrlOptions) + this(None, size, openTimeout, idleTimeout, repopulateInterval, defaultUrlOptions, None) } def this(size: Int, openTimeout: Duration, idleTimeout: Duration, @@ -192,7 +274,7 @@ class ThrottledPoolingDatabaseFactory( } new ThrottledPoolingDatabase(serviceName, dbhosts, dbname, username, password, finalUrlOptions, - size, openTimeout, idleTimeout, repopulateInterval) + size, openTimeout, idleTimeout, repopulateInterval, failFastPolicy) } } @@ -206,11 +288,13 @@ class ThrottledPoolingDatabase( numConnections: Int, val openTimeout: Duration, idleTimeout: Duration, - repopulateInterval: Duration) extends Database { + repopulateInterval: Duration, + val failFastPolicy: FailFastPolicy) extends Database { Class.forName("com.mysql.jdbc.Driver") - private val pool = new ThrottledPool(mkConnection, numConnections, openTimeout, idleTimeout) + private[database] val pool = new ThrottledPool(mkConnection, numConnections, openTimeout, + idleTimeout, failFastPolicy) private val poolingDataSource = new PoolingDataSource(pool) poolingDataSource.setAccessToUnderlyingConnectionAllowed(true) new PoolWatchdogThread(pool, hosts, repopulateInterval).start() @@ -226,7 +310,7 @@ class ThrottledPoolingDatabase( extraUrlOptions: Map[String, String], numConnections: Int, openTimeout: Duration, idleTimeout: Duration, repopulateInterval: Duration) = { this(None, hosts, name, username, password, extraUrlOptions, numConnections, openTimeout, - idleTimeout, repopulateInterval) + idleTimeout, repopulateInterval, FailFastBasedOnNumConnsPolicy(openTimeout)) } def open() = { diff --git a/querulous-core/src/test/scala/com/twitter/querulous/integration/ThrottledPoolingDatabaseWithFakeConnSpec.scala b/querulous-core/src/test/scala/com/twitter/querulous/integration/ThrottledPoolingDatabaseWithFakeConnSpec.scala index c7c2b17..bba5346 100644 --- a/querulous-core/src/test/scala/com/twitter/querulous/integration/ThrottledPoolingDatabaseWithFakeConnSpec.scala +++ b/querulous-core/src/test/scala/com/twitter/querulous/integration/ThrottledPoolingDatabaseWithFakeConnSpec.scala @@ -1,13 +1,26 @@ package com.twitter.querulous.integration -import com.twitter.util.Time import com.twitter.conversions.time._ -import com.twitter.querulous.evaluator.StandardQueryEvaluatorFactory +import com.twitter.querulous.evaluator.{QueryEvaluator, StandardQueryEvaluatorFactory} import com.twitter.querulous.ConfiguredSpecification import com.twitter.querulous.sql.{FakeContext, FakeDriver} import com.mysql.jdbc.exceptions.jdbc4.CommunicationsException -import com.twitter.querulous.database.{PoolEmptyException, Database, ThrottledPoolingDatabaseFactory} import com.twitter.querulous.query.{SqlQueryTimeoutException, TimingOutQueryFactory, SqlQueryFactory} +import collection.immutable.Vector +import util.Random +import com.twitter.util.{Duration, Time} +import com.twitter.querulous.database.{PoolFailFastException, PoolEmptyException, Database, ThrottledPoolingDatabaseFactory} +import com.twitter.querulous.config.FailFastPolicyConfig + +class MockRandom extends Random { + private[this] var currIndex = 0 + private[this] val vals: IndexedSeq[Double] = Vector(0.75, 0.85, 0.25, 0.25, 0.25) + // cycle through the list of values specified above + override def nextDouble(): Double = { + currIndex += 1 + vals((currIndex - 1)%vals.size) + } +} object ThrottledPoolingDatabaseWithFakeConnSpec { // configure repopulation interval to a minute to avoid conn repopulation when test running @@ -22,25 +35,58 @@ object ThrottledPoolingDatabaseWithFakeConnSpec { 1.second, 100.milliseconds, Map("connectTimeout" -> "2000")) val testRepopulatedLongConnTimeoutEvaluatorFactory = new StandardQueryEvaluatorFactory( testRepopulatedLongConnTimeoutDbFactory, testQueryFactory) + + val failFastPolicyConfig = new FailFastPolicyConfig { + def highWaterMark: Double = 0.75 + def openTimeout: Duration = 1.second + def lowWaterMark: Double = 0.5 + def rng: Option[Random] = Some(new MockRandom()) + } + val testFailFastDatabaseFactory = new ThrottledPoolingDatabaseFactory(Some("test"), 8, 1.second, + 60.second, 60.seconds, Map.empty, Some(failFastPolicyConfig)) + val testFailFastEvaluatorFactory = new StandardQueryEvaluatorFactory(testFailFastDatabaseFactory, + testQueryFactory) + + val testDatabaseFactoryWithDefaultFailFastPolicy = new ThrottledPoolingDatabaseFactory( + Some("test"), 8, 1.second, 60.second, 60.seconds, Map.empty, None) + val testEvaluatorFactoryWithDefaultFailFastPolicy = new StandardQueryEvaluatorFactory( + testDatabaseFactoryWithDefaultFailFastPolicy, testQueryFactory) + + def destroyConnection(queryEvaluator: QueryEvaluator, host: String, numConns: Int = 1) { + FakeContext.markServerDown(host) + try { + for(i <- 0 until numConns) { + try { + queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } + assert(false) + } catch { + case e: CommunicationsException => // expected + case t: Throwable => throw t + } + } + } finally { + FakeContext.markServerUp(host) + } + } } class ThrottledPoolingDatabaseWithFakeConnSpec extends ConfiguredSpecification { import ThrottledPoolingDatabaseWithFakeConnSpec._ + val host = config.hostnames.mkString(",") + "/" + config.database + FakeContext.setQueryResult(host, "SELECT 1 FROM DUAL", Array(Array[java.lang.Object](1.asInstanceOf[AnyRef]))) + FakeContext.setQueryResult(host, "SELECT 2 FROM DUAL", Array(Array[java.lang.Object](2.asInstanceOf[AnyRef]))) + doBeforeSpec { Database.driverName = FakeDriver.DRIVER_NAME } "ThrottledJdbcPoolSpec" should { - val host = config.hostnames.mkString(",") + "/" + config.database val queryEvaluator = testEvaluatorFactory(config) - - FakeContext.setQueryResult(host, "SELECT 1 FROM DUAL", Array(Array[java.lang.Object](1.asInstanceOf[AnyRef]))) - FakeContext.setQueryResult(host, "SELECT 2 FROM DUAL", Array(Array[java.lang.Object](2.asInstanceOf[AnyRef]))) - "execute some queries" >> { + "execute some queries" in { queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1) queryEvaluator.select("SELECT 2 FROM DUAL") { r => r.getInt(1) } mustEqual List(2) } - "failfast after a host is down" >> { + "failfast after a host is down" in { queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1) FakeContext.markServerDown(host) try { @@ -53,7 +99,7 @@ class ThrottledPoolingDatabaseWithFakeConnSpec extends ConfiguredSpecification { } } - "failfast after connections are closed due to query timeout" >> { + "failfast after connections are closed due to query timeout" in { queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1) FakeContext.setTimeTakenToExecQuery(host, 1.second) try { @@ -67,7 +113,7 @@ class ThrottledPoolingDatabaseWithFakeConnSpec extends ConfiguredSpecification { } } - "repopulate the pool every repopulation interval" >> { + "repopulate the pool every repopulation interval" in { val queryEvaluator = testRepopulatedEvaluatorFactory(config) queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1) @@ -84,7 +130,7 @@ class ThrottledPoolingDatabaseWithFakeConnSpec extends ConfiguredSpecification { } } - "repopulate the pool even if it takes longer to establish a connection than repopulation interval" >> { + "repopulate the pool even if it takes longer to establish a connection than repopulation interval" in { val queryEvaluator = testRepopulatedLongConnTimeoutEvaluatorFactory(config) queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1) @@ -104,5 +150,74 @@ class ThrottledPoolingDatabaseWithFakeConnSpec extends ConfiguredSpecification { } } + "ThrottledJdbcPoolWithFailFastPolicy" should { + val queryEvaluator = testFailFastEvaluatorFactory(config) + + "execute query normally as the pool is full or above highWaterMark" in { + // number of connections in the pool 8 + queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1) + destroyConnection(queryEvaluator, host) + // number of connections in the pool 7 + queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1) + } + + // here we should see more aggressive timeout applied on acquiring connection, but currently + // it is hard to test that. Nonetheless, this is covered by the unit test. + "execute query normally until the pool reaches the lowWaterMark" in { + destroyConnection(queryEvaluator, host, 2) + // number of connections in the pool 6 = 8 * 0.75, at highWaterMark + queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1) + destroyConnection(queryEvaluator, host) + queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1) + destroyConnection(queryEvaluator, host) + // number of connection is 4 = 8 * 0.5, at lowWaterMark + queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1) + } + + "fail fast when the pool is under the lowWaterMark and throw PoolEmptyException when the pool is empty" in { + destroyConnection(queryEvaluator, host, 5) + // number of connection is at 3, but still fine because the first double out of random number + // generator is 0.75, 3 = 8 * 0.5 * 0.75 + queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1) + // 3 < 8 * 0.5 * 0.5 + queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } must throwA[PoolFailFastException] + destroyConnection(queryEvaluator, host, 3) + queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } must throwA[PoolEmptyException] + } + } + + "ThrottledJdbcPoolWithDefaultFailFastPolicy" should { + "execute query normally as the pool until there is no connection in the pool" in { + val queryEvaluator = testEvaluatorFactoryWithDefaultFailFastPolicy(config) + + // number of connections in the pool 8 + queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1) + destroyConnection(queryEvaluator, host) + // number of connections in the pool 7 + queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1) + destroyConnection(queryEvaluator, host) + // number of connections in the pool 6 + queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1) + destroyConnection(queryEvaluator, host) + // number of connections in the pool 5 + queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1) + destroyConnection(queryEvaluator, host) + // number of connections in the pool 4 + queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1) + destroyConnection(queryEvaluator, host) + // number of connections in the pool 3 + queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1) + destroyConnection(queryEvaluator, host) + // number of connections in the pool 2 + queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1) + destroyConnection(queryEvaluator, host) + // number of connections in the pool 1 + queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } mustEqual List(1) + destroyConnection(queryEvaluator, host) + // no connection left in the pool + queryEvaluator.select("SELECT 1 FROM DUAL") { r => r.getInt(1) } must throwA[PoolEmptyException] + } + } + doAfterSpec { Database.driverName = "jdbc:mysql" } } diff --git a/querulous-core/src/test/scala/com/twitter/querulous/unit/DefaultFailFastPolicySpec.scala b/querulous-core/src/test/scala/com/twitter/querulous/unit/DefaultFailFastPolicySpec.scala new file mode 100644 index 0000000..9d52fd6 --- /dev/null +++ b/querulous-core/src/test/scala/com/twitter/querulous/unit/DefaultFailFastPolicySpec.scala @@ -0,0 +1,33 @@ +package com.twitter.querulous.unit + +import com.twitter.conversions.time._ +import org.specs.Specification +import org.specs.mock.{ClassMocker, JMocker} +import java.sql.Connection +import com.twitter.querulous.database.{FailFastBasedOnNumConnsPolicy, PoolEmptyException, ThrottledPool} + +class DefaultFailFastPolicySpec extends Specification with JMocker with ClassMocker { + val conn = mock[Connection] + val connFactory = mock[TestConnectionFactory] + val ffp = FailFastBasedOnNumConnsPolicy(0, 0, 1.second, None) + val pool = mock[ThrottledPool] + + "DefaultFailFastPolicySpec" should { + "throw PoolEmptyException when the pool does not have any connection" in { + expect { + one(pool).getTotal() willReturn 0 + } + ffp.failFast(pool)(connFactory.getConnection(_)) must throwA[PoolEmptyException] + } + + "get a connection with the normal timeout setting even if the pool only has one connection" in { + expect { + one(pool).getTotal() willReturn 1 + 2.of(pool).size willReturn 8 + one(pool).timeout willReturn 2.seconds + one(connFactory).getConnection(2.seconds) willReturn conn + } + ffp.failFast(pool)(connFactory.getConnection(_)) mustEqual conn + } + } +} diff --git a/querulous-core/src/test/scala/com/twitter/querulous/unit/FailFastBasedOnNumConnsPolicySpec.scala b/querulous-core/src/test/scala/com/twitter/querulous/unit/FailFastBasedOnNumConnsPolicySpec.scala new file mode 100644 index 0000000..b660240 --- /dev/null +++ b/querulous-core/src/test/scala/com/twitter/querulous/unit/FailFastBasedOnNumConnsPolicySpec.scala @@ -0,0 +1,101 @@ +package com.twitter.querulous.unit + +import com.twitter.conversions.time._ +import org.specs.Specification +import org.specs.mock.{ClassMocker, JMocker} +import com.twitter.util.Duration +import java.sql.Connection +import java.security.InvalidParameterException +import com.twitter.querulous.database.{FailFastBasedOnNumConnsPolicy, PoolFailFastException, PoolEmptyException, ThrottledPool} + +trait TestConnectionFactory { + def getConnection(timeout: Duration): Connection +} + +class FailFastBasedOnNumConnsPolicySpec extends Specification with JMocker with ClassMocker { + val conn = mock[Connection] + val rng = mock[scala.util.Random] + val connFactory = mock[TestConnectionFactory] + val ffp = FailFastBasedOnNumConnsPolicy(0.75, 0.5, 1.second, Some(rng)) + val pool = mock[ThrottledPool] + + "FailFastBasedOnNumConnsPolicySpec" should { + "throw InvalidParameterException when highWaterMark is lower than lowWaterMark" in { + FailFastBasedOnNumConnsPolicy(0.5, 0.75, 1.second, Some(rng)) must throwA[InvalidParameterException] + } + + "throw InvalidParameterException when lowWaterMark is negative or 0" in { + FailFastBasedOnNumConnsPolicy(0.5, -0.1, 1.second, Some(rng)) must throwA[InvalidParameterException] + } + + "throw InvalidParameterException when hightWaterMark is greater than 1" in { + FailFastBasedOnNumConnsPolicy(1.1, 0.75, 1.second, Some(rng)) must throwA[InvalidParameterException] + } + + "throw PoolEmptyException when the pool does not have any connection" in { + expect { + one(pool).getTotal() willReturn 0 + } + ffp.failFast(pool)(connFactory.getConnection(_)) must throwA[PoolEmptyException] + } + + "throw PoolFailFastException when pool below lowWaterMark and unlucky" in { + expect { + one(pool).getTotal() willReturn 2 + one(pool).size willReturn 8 + one(rng).nextDouble() willReturn 0.75 + one(pool).size willReturn 8 + } + ffp.failFast(pool)(connFactory.getConnection(_)) must throwA[PoolFailFastException] + } + + "get a connection with the more aggressive timeout setting when pool below lowWaterMark and but lucky enough" in { + expect { + one(pool).getTotal() willReturn 2 + one(pool).size willReturn 8 + one(rng).nextDouble() willReturn 0.25 + one(pool).size willReturn 8 + one(connFactory).getConnection(1.second) willReturn conn + } + ffp.failFast(pool)(connFactory.getConnection(_)) mustEqual conn + } + + "get a connection with the more aggressive timeout setting when pool below highWaterMark but at lowWaterMark" in { + expect { + one(pool).getTotal() willReturn 4 + 2.of(pool).size willReturn 8 + one(connFactory).getConnection(1.second) willReturn conn + } + ffp.failFast(pool)(connFactory.getConnection(_)) mustEqual conn + } + + "get a connection with the more aggressive timeout setting when pool below highWaterMark but above lowWaterMark" in { + expect { + one(pool).getTotal() willReturn 5 + 2.of(pool).size willReturn 8 + one(connFactory).getConnection(1.second) willReturn conn + } + ffp.failFast(pool)(connFactory.getConnection(_)) mustEqual conn + } + + "get a connection with the normal timeout setting when pool at highWaterMark" in { + expect { + one(pool).getTotal() willReturn 6 + 2.of(pool).size willReturn 8 + one(pool).timeout willReturn 2.seconds + one(connFactory).getConnection(2.seconds) willReturn conn + } + ffp.failFast(pool)(connFactory.getConnection(_)) mustEqual conn + } + + "get a connection with the normal timeout setting when pool above highWaterMark" in { + expect { + one(pool).getTotal() willReturn 7 + 2.of(pool).size willReturn 8 + one(pool).timeout willReturn 2.seconds + one(connFactory).getConnection(2.seconds) willReturn conn + } + ffp.failFast(pool)(connFactory.getConnection(_)) mustEqual conn + } + } +}