Skip to content
This repository has been archived by the owner on Sep 18, 2021. It is now read-only.

Commit

Permalink
fail fast when the connection pool is under-populated, corresponding …
Browse files Browse the repository at this point in the history
…config, unit and integration tests
  • Loading branch information
yswu committed Aug 17, 2011
1 parent 83ac979 commit f9f4242
Show file tree
Hide file tree
Showing 5 changed files with 376 additions and 34 deletions.
Expand Up @@ -4,14 +4,21 @@ import com.twitter.querulous._
import com.twitter.util.Duration
import com.twitter.conversions.time._
import database._
import util.Random

trait FailFastPolicyConfig {
def highWaterMark: Double
def lowWaterMark: Double
def openTimeout: Duration
def rng: Option[Random]
}

trait PoolingDatabase {
def apply(): DatabaseFactory
}

trait ServiceNameTagged {
def apply(serviceName: Option[String]): DatabaseFactory
trait ServiceNameAndFailFastPolicy {
def apply(serviceName: Option[String], ffp: Option[FailFastPolicyConfig]): DatabaseFactory
}

class ApachePoolingDatabase extends PoolingDatabase {
Expand All @@ -28,19 +35,19 @@ class ApachePoolingDatabase extends PoolingDatabase {
}
}

class ThrottledPoolingDatabase extends PoolingDatabase with ServiceNameTagged {
class ThrottledPoolingDatabase extends PoolingDatabase with ServiceNameAndFailFastPolicy {
var size: Int = 10
var openTimeout: Duration = 50.millis
var repopulateInterval: Duration = 500.millis
var idleTimeout: Duration = 1.minute

def apply() = {
apply(None)
apply(None, None)
}

def apply(serviceName: Option[String]) = {
def apply(serviceName: Option[String], ffp: Option[FailFastPolicyConfig]) = {
new ThrottledPoolingDatabaseFactory(
serviceName, size, openTimeout, idleTimeout, repopulateInterval, Map.empty)
serviceName, size, openTimeout, idleTimeout, repopulateInterval, Map.empty, ffp)
}
}

Expand Down Expand Up @@ -73,14 +80,16 @@ class Database {
var memoize: Boolean = true
var serviceName: Option[String] = None
def serviceName_=(s: String) { serviceName = Some(s) }
var failFastPolicyConfig: Option[FailFastPolicyConfig] = None
def failFastPolicyConfig_=(ffp: FailFastPolicyConfig) { failFastPolicyConfig = Some(ffp) }

def apply(stats: StatsCollector): DatabaseFactory = apply(stats, None)

def apply(stats: StatsCollector, statsFactory: DatabaseFactory => DatabaseFactory): DatabaseFactory = apply(stats, Some(statsFactory))

def apply(stats: StatsCollector, statsFactory: Option[DatabaseFactory => DatabaseFactory]): DatabaseFactory = {
var factory = pool.map{ _ match {
case p: ServiceNameTagged => p(serviceName)
case p: ServiceNameAndFailFastPolicy => p(serviceName, failFastPolicyConfig)
case p: PoolingDatabase => p()
}}.getOrElse(new SingleConnectionDatabaseFactory)

Expand Down
Expand Up @@ -4,14 +4,84 @@ import java.util.concurrent.{TimeUnit, LinkedBlockingQueue}
import java.sql.{SQLException, DriverManager, Connection}
import org.apache.commons.dbcp.{PoolingDataSource, DelegatingConnection}
import org.apache.commons.pool.{PoolableObjectFactory, ObjectPool}
import com.twitter.querulous.config.FailFastPolicyConfig
import com.twitter.util.Duration
import com.twitter.util.Time
import scala.annotation.tailrec
import java.lang.Thread
import java.util.concurrent.atomic.AtomicInteger
import java.security.InvalidParameterException
import util.Random
import java.lang.{UnsupportedOperationException, Thread}

class FailToAcquireConnectionException extends SQLException
class PoolTimeoutException extends FailToAcquireConnectionException
class PoolFailFastException extends FailToAcquireConnectionException
class PoolEmptyException extends PoolFailFastException

/**
* determine whether to fail fast when trying to check out a connection from pool
*/
trait FailFastPolicy {
/**
* This method throws PoolFailFastException when it decides to fail fast given the current state
* of the underlying database, or it could throw PoolTimeoutException when failing to acquire
* a connection within specified time frame
*
* @param db the database from which we are going to open connections
*/
@throws(classOf[FailToAcquireConnectionException])
def failFast(pool: ObjectPool)(f: Duration => Connection): Connection
}

class PoolTimeoutException extends SQLException
class PoolEmptyException extends SQLException
/**
* This policy behaves in the way specified as follows:
* When the number of connections in the pool is below the highWaterMark, start to use the timeout
* passed in when waiting for a connection; when it is below the lowWaterMark, start to fail
* immediately proportional to the number of connections available in the pool, with 100% failure
* rate when the pool is empty
*/
class FailFastBasedOnNumConnsPolicy(val highWaterMark: Double, val lowWaterMark: Double,
val openTimeout: Duration, val rng: Random) extends FailFastPolicy {
if (highWaterMark < lowWaterMark || highWaterMark > 1 || lowWaterMark < 0) {
throw new InvalidParameterException("invalid water mark")
}

@throws(classOf[FailToAcquireConnectionException])
def failFast(pool: ObjectPool)(f: Duration => Connection) = {
pool match {
case p: ThrottledPool => {
val numConn = p.getTotal()
if (numConn == 0) {
throw new PoolEmptyException
} else if (numConn < p.size * lowWaterMark) {
if(numConn < rng.nextDouble() * p.size * lowWaterMark) {
throw new PoolFailFastException
} else {
// should still try to do aggressive timeout at least
f(openTimeout)
}
} else if (numConn < p.size * highWaterMark) {
f(openTimeout)
} else {
f(p.timeout)
}
}
case _ => throw new UnsupportedOperationException("Only support ThrottledPoolingDatabase")
}
}
}

object FailFastBasedOnNumConnsPolicy {
def apply(openTimeout: Duration): FailFastBasedOnNumConnsPolicy = {
apply(0, 0, openTimeout, Some(new Random(System.currentTimeMillis())))
}

def apply(highWaterMark: Double, lowWaterMark: Double, openTimeout: Duration,
rng: Option[Random]): FailFastBasedOnNumConnsPolicy = {
new FailFastBasedOnNumConnsPolicy(highWaterMark, lowWaterMark, openTimeout,
rng.getOrElse(new Random(System.currentTimeMillis())))
}
}

class PooledConnection(c: Connection, p: ObjectPool) extends DelegatingConnection(c) {
private var pool: Option[ObjectPool] = Some(p)
Expand Down Expand Up @@ -43,12 +113,16 @@ class PooledConnection(c: Connection, p: ObjectPool) extends DelegatingConnectio
}
}

class ThrottledPool(factory: () => Connection, val size: Int, timeout: Duration,
idleTimeout: Duration) extends ObjectPool {
case class ThrottledPool(factory: () => Connection, val size: Int, timeout: Duration,
idleTimeout: Duration, failFastPolicy: FailFastPolicy) extends ObjectPool {
private val pool = new LinkedBlockingQueue[(Connection, Time)]()
private val currentSize = new AtomicInteger(0)
private val numWaiters = new AtomicInteger(0)

def this(factory: () => Connection, size: Int, timeout: Duration, idleTimeout: Duration) = {
this(factory, size, timeout, idleTimeout, FailFastBasedOnNumConnsPolicy(timeout))
}

for (i <- (0.until(size))) addObject()

def addObject() {
Expand All @@ -69,17 +143,17 @@ class ThrottledPool(factory: () => Connection, val size: Int, timeout: Duration,
final def borrowObject(): Connection = {
numWaiters.incrementAndGet()
try {
borrowObjectInternal()
failFastPolicy.failFast(this)(borrowObjectInternal)
} finally {
numWaiters.decrementAndGet()
}
}

@tailrec private def borrowObjectInternal(): Connection = {
@tailrec private def borrowObjectInternal(openTimeout: Duration): Connection = {
// short circuit if the pool is empty
if (getTotal() == 0) throw new PoolEmptyException

val pair = pool.poll(timeout.inMillis, TimeUnit.MILLISECONDS)
val pair = pool.poll(openTimeout.inMillis, TimeUnit.MILLISECONDS)
if (pair == null) throw new PoolTimeoutException
val (connection, lastUse) = pair

Expand All @@ -88,7 +162,7 @@ class ThrottledPool(factory: () => Connection, val size: Int, timeout: Duration,
try { connection.close() } catch { case _: SQLException => }
// note: dbcp handles object invalidation here.
addObjectIfEmpty()
borrowObjectInternal()
borrowObjectInternal(openTimeout)
} else {
connection
}
Expand Down Expand Up @@ -170,11 +244,19 @@ class ThrottledPoolingDatabaseFactory(
openTimeout: Duration,
idleTimeout: Duration,
repopulateInterval: Duration,
defaultUrlOptions: Map[String, String]) extends DatabaseFactory {
defaultUrlOptions: Map[String, String],
failFastPolicyConfig: Option[FailFastPolicyConfig]) extends DatabaseFactory {

// the default is the one with both highWaterMark and lowWaterMark of 0
// in this case, PoolEmptyException will be thrown when the number of connections in the pool
// is zero; otherwise, it will behave the same way as if this policy is not applied
private val failFastPolicy = failFastPolicyConfig map {pc =>
FailFastBasedOnNumConnsPolicy(pc.highWaterMark, pc.lowWaterMark, pc.openTimeout, pc.rng)
} getOrElse (FailFastBasedOnNumConnsPolicy(openTimeout))

def this(size: Int, openTimeout: Duration, idleTimeout: Duration, repopulateInterval: Duration,
defaultUrlOptions: Map[String, String]) = {
this(None, size, openTimeout, idleTimeout, repopulateInterval, defaultUrlOptions)
this(None, size, openTimeout, idleTimeout, repopulateInterval, defaultUrlOptions, None)
}

def this(size: Int, openTimeout: Duration, idleTimeout: Duration,
Expand All @@ -192,7 +274,7 @@ class ThrottledPoolingDatabaseFactory(
}

new ThrottledPoolingDatabase(serviceName, dbhosts, dbname, username, password, finalUrlOptions,
size, openTimeout, idleTimeout, repopulateInterval)
size, openTimeout, idleTimeout, repopulateInterval, failFastPolicy)
}
}

Expand All @@ -206,11 +288,13 @@ class ThrottledPoolingDatabase(
numConnections: Int,
val openTimeout: Duration,
idleTimeout: Duration,
repopulateInterval: Duration) extends Database {
repopulateInterval: Duration,
val failFastPolicy: FailFastPolicy) extends Database {

Class.forName("com.mysql.jdbc.Driver")

private val pool = new ThrottledPool(mkConnection, numConnections, openTimeout, idleTimeout)
private[database] val pool = new ThrottledPool(mkConnection, numConnections, openTimeout,
idleTimeout, failFastPolicy)
private val poolingDataSource = new PoolingDataSource(pool)
poolingDataSource.setAccessToUnderlyingConnectionAllowed(true)
new PoolWatchdogThread(pool, hosts, repopulateInterval).start()
Expand All @@ -226,7 +310,7 @@ class ThrottledPoolingDatabase(
extraUrlOptions: Map[String, String], numConnections: Int, openTimeout: Duration,
idleTimeout: Duration, repopulateInterval: Duration) = {
this(None, hosts, name, username, password, extraUrlOptions, numConnections, openTimeout,
idleTimeout, repopulateInterval)
idleTimeout, repopulateInterval, FailFastBasedOnNumConnsPolicy(openTimeout))
}

def open() = {
Expand Down

0 comments on commit f9f4242

Please sign in to comment.