Skip to content
This repository has been archived by the owner on Sep 18, 2021. It is now read-only.

Commit

Permalink
fail fast when the connection pool is under-populated, corresponding …
Browse files Browse the repository at this point in the history
…config, unit and integration tests
  • Loading branch information
yswu committed Aug 17, 2011
1 parent 83ac979 commit f9f4242
Show file tree
Hide file tree
Showing 5 changed files with 376 additions and 34 deletions.
Expand Up @@ -4,14 +4,21 @@ import com.twitter.querulous._
import com.twitter.util.Duration import com.twitter.util.Duration
import com.twitter.conversions.time._ import com.twitter.conversions.time._
import database._ import database._
import util.Random


trait FailFastPolicyConfig {
def highWaterMark: Double
def lowWaterMark: Double
def openTimeout: Duration
def rng: Option[Random]
}


trait PoolingDatabase { trait PoolingDatabase {
def apply(): DatabaseFactory def apply(): DatabaseFactory
} }


trait ServiceNameTagged { trait ServiceNameAndFailFastPolicy {
def apply(serviceName: Option[String]): DatabaseFactory def apply(serviceName: Option[String], ffp: Option[FailFastPolicyConfig]): DatabaseFactory
} }


class ApachePoolingDatabase extends PoolingDatabase { class ApachePoolingDatabase extends PoolingDatabase {
Expand All @@ -28,19 +35,19 @@ class ApachePoolingDatabase extends PoolingDatabase {
} }
} }


class ThrottledPoolingDatabase extends PoolingDatabase with ServiceNameTagged { class ThrottledPoolingDatabase extends PoolingDatabase with ServiceNameAndFailFastPolicy {
var size: Int = 10 var size: Int = 10
var openTimeout: Duration = 50.millis var openTimeout: Duration = 50.millis
var repopulateInterval: Duration = 500.millis var repopulateInterval: Duration = 500.millis
var idleTimeout: Duration = 1.minute var idleTimeout: Duration = 1.minute


def apply() = { def apply() = {
apply(None) apply(None, None)
} }


def apply(serviceName: Option[String]) = { def apply(serviceName: Option[String], ffp: Option[FailFastPolicyConfig]) = {
new ThrottledPoolingDatabaseFactory( new ThrottledPoolingDatabaseFactory(
serviceName, size, openTimeout, idleTimeout, repopulateInterval, Map.empty) serviceName, size, openTimeout, idleTimeout, repopulateInterval, Map.empty, ffp)
} }
} }


Expand Down Expand Up @@ -73,14 +80,16 @@ class Database {
var memoize: Boolean = true var memoize: Boolean = true
var serviceName: Option[String] = None var serviceName: Option[String] = None
def serviceName_=(s: String) { serviceName = Some(s) } def serviceName_=(s: String) { serviceName = Some(s) }
var failFastPolicyConfig: Option[FailFastPolicyConfig] = None
def failFastPolicyConfig_=(ffp: FailFastPolicyConfig) { failFastPolicyConfig = Some(ffp) }


def apply(stats: StatsCollector): DatabaseFactory = apply(stats, None) def apply(stats: StatsCollector): DatabaseFactory = apply(stats, None)


def apply(stats: StatsCollector, statsFactory: DatabaseFactory => DatabaseFactory): DatabaseFactory = apply(stats, Some(statsFactory)) def apply(stats: StatsCollector, statsFactory: DatabaseFactory => DatabaseFactory): DatabaseFactory = apply(stats, Some(statsFactory))


def apply(stats: StatsCollector, statsFactory: Option[DatabaseFactory => DatabaseFactory]): DatabaseFactory = { def apply(stats: StatsCollector, statsFactory: Option[DatabaseFactory => DatabaseFactory]): DatabaseFactory = {
var factory = pool.map{ _ match { var factory = pool.map{ _ match {
case p: ServiceNameTagged => p(serviceName) case p: ServiceNameAndFailFastPolicy => p(serviceName, failFastPolicyConfig)
case p: PoolingDatabase => p() case p: PoolingDatabase => p()
}}.getOrElse(new SingleConnectionDatabaseFactory) }}.getOrElse(new SingleConnectionDatabaseFactory)


Expand Down
Expand Up @@ -4,14 +4,84 @@ import java.util.concurrent.{TimeUnit, LinkedBlockingQueue}
import java.sql.{SQLException, DriverManager, Connection} import java.sql.{SQLException, DriverManager, Connection}
import org.apache.commons.dbcp.{PoolingDataSource, DelegatingConnection} import org.apache.commons.dbcp.{PoolingDataSource, DelegatingConnection}
import org.apache.commons.pool.{PoolableObjectFactory, ObjectPool} import org.apache.commons.pool.{PoolableObjectFactory, ObjectPool}
import com.twitter.querulous.config.FailFastPolicyConfig
import com.twitter.util.Duration import com.twitter.util.Duration
import com.twitter.util.Time import com.twitter.util.Time
import scala.annotation.tailrec import scala.annotation.tailrec
import java.lang.Thread
import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.atomic.AtomicInteger
import java.security.InvalidParameterException
import util.Random
import java.lang.{UnsupportedOperationException, Thread}

class FailToAcquireConnectionException extends SQLException
class PoolTimeoutException extends FailToAcquireConnectionException
class PoolFailFastException extends FailToAcquireConnectionException
class PoolEmptyException extends PoolFailFastException

/**
* determine whether to fail fast when trying to check out a connection from pool
*/
trait FailFastPolicy {
/**
* This method throws PoolFailFastException when it decides to fail fast given the current state
* of the underlying database, or it could throw PoolTimeoutException when failing to acquire
* a connection within specified time frame
*
* @param db the database from which we are going to open connections
*/
@throws(classOf[FailToAcquireConnectionException])
def failFast(pool: ObjectPool)(f: Duration => Connection): Connection
}


class PoolTimeoutException extends SQLException /**
class PoolEmptyException extends SQLException * This policy behaves in the way specified as follows:
* When the number of connections in the pool is below the highWaterMark, start to use the timeout
* passed in when waiting for a connection; when it is below the lowWaterMark, start to fail
* immediately proportional to the number of connections available in the pool, with 100% failure
* rate when the pool is empty
*/
class FailFastBasedOnNumConnsPolicy(val highWaterMark: Double, val lowWaterMark: Double,
val openTimeout: Duration, val rng: Random) extends FailFastPolicy {
if (highWaterMark < lowWaterMark || highWaterMark > 1 || lowWaterMark < 0) {
throw new InvalidParameterException("invalid water mark")
}

@throws(classOf[FailToAcquireConnectionException])
def failFast(pool: ObjectPool)(f: Duration => Connection) = {
pool match {
case p: ThrottledPool => {
val numConn = p.getTotal()
if (numConn == 0) {
throw new PoolEmptyException
} else if (numConn < p.size * lowWaterMark) {
if(numConn < rng.nextDouble() * p.size * lowWaterMark) {
throw new PoolFailFastException
} else {
// should still try to do aggressive timeout at least
f(openTimeout)
}
} else if (numConn < p.size * highWaterMark) {
f(openTimeout)
} else {
f(p.timeout)
}
}
case _ => throw new UnsupportedOperationException("Only support ThrottledPoolingDatabase")
}
}
}

object FailFastBasedOnNumConnsPolicy {
def apply(openTimeout: Duration): FailFastBasedOnNumConnsPolicy = {
apply(0, 0, openTimeout, Some(new Random(System.currentTimeMillis())))
}

def apply(highWaterMark: Double, lowWaterMark: Double, openTimeout: Duration,
rng: Option[Random]): FailFastBasedOnNumConnsPolicy = {
new FailFastBasedOnNumConnsPolicy(highWaterMark, lowWaterMark, openTimeout,
rng.getOrElse(new Random(System.currentTimeMillis())))
}
}


class PooledConnection(c: Connection, p: ObjectPool) extends DelegatingConnection(c) { class PooledConnection(c: Connection, p: ObjectPool) extends DelegatingConnection(c) {
private var pool: Option[ObjectPool] = Some(p) private var pool: Option[ObjectPool] = Some(p)
Expand Down Expand Up @@ -43,12 +113,16 @@ class PooledConnection(c: Connection, p: ObjectPool) extends DelegatingConnectio
} }
} }


class ThrottledPool(factory: () => Connection, val size: Int, timeout: Duration, case class ThrottledPool(factory: () => Connection, val size: Int, timeout: Duration,
idleTimeout: Duration) extends ObjectPool { idleTimeout: Duration, failFastPolicy: FailFastPolicy) extends ObjectPool {
private val pool = new LinkedBlockingQueue[(Connection, Time)]() private val pool = new LinkedBlockingQueue[(Connection, Time)]()
private val currentSize = new AtomicInteger(0) private val currentSize = new AtomicInteger(0)
private val numWaiters = new AtomicInteger(0) private val numWaiters = new AtomicInteger(0)


def this(factory: () => Connection, size: Int, timeout: Duration, idleTimeout: Duration) = {
this(factory, size, timeout, idleTimeout, FailFastBasedOnNumConnsPolicy(timeout))
}

for (i <- (0.until(size))) addObject() for (i <- (0.until(size))) addObject()


def addObject() { def addObject() {
Expand All @@ -69,17 +143,17 @@ class ThrottledPool(factory: () => Connection, val size: Int, timeout: Duration,
final def borrowObject(): Connection = { final def borrowObject(): Connection = {
numWaiters.incrementAndGet() numWaiters.incrementAndGet()
try { try {
borrowObjectInternal() failFastPolicy.failFast(this)(borrowObjectInternal)
} finally { } finally {
numWaiters.decrementAndGet() numWaiters.decrementAndGet()
} }
} }


@tailrec private def borrowObjectInternal(): Connection = { @tailrec private def borrowObjectInternal(openTimeout: Duration): Connection = {
// short circuit if the pool is empty // short circuit if the pool is empty
if (getTotal() == 0) throw new PoolEmptyException if (getTotal() == 0) throw new PoolEmptyException


val pair = pool.poll(timeout.inMillis, TimeUnit.MILLISECONDS) val pair = pool.poll(openTimeout.inMillis, TimeUnit.MILLISECONDS)
if (pair == null) throw new PoolTimeoutException if (pair == null) throw new PoolTimeoutException
val (connection, lastUse) = pair val (connection, lastUse) = pair


Expand All @@ -88,7 +162,7 @@ class ThrottledPool(factory: () => Connection, val size: Int, timeout: Duration,
try { connection.close() } catch { case _: SQLException => } try { connection.close() } catch { case _: SQLException => }
// note: dbcp handles object invalidation here. // note: dbcp handles object invalidation here.
addObjectIfEmpty() addObjectIfEmpty()
borrowObjectInternal() borrowObjectInternal(openTimeout)
} else { } else {
connection connection
} }
Expand Down Expand Up @@ -170,11 +244,19 @@ class ThrottledPoolingDatabaseFactory(
openTimeout: Duration, openTimeout: Duration,
idleTimeout: Duration, idleTimeout: Duration,
repopulateInterval: Duration, repopulateInterval: Duration,
defaultUrlOptions: Map[String, String]) extends DatabaseFactory { defaultUrlOptions: Map[String, String],
failFastPolicyConfig: Option[FailFastPolicyConfig]) extends DatabaseFactory {

// the default is the one with both highWaterMark and lowWaterMark of 0
// in this case, PoolEmptyException will be thrown when the number of connections in the pool
// is zero; otherwise, it will behave the same way as if this policy is not applied
private val failFastPolicy = failFastPolicyConfig map {pc =>
FailFastBasedOnNumConnsPolicy(pc.highWaterMark, pc.lowWaterMark, pc.openTimeout, pc.rng)
} getOrElse (FailFastBasedOnNumConnsPolicy(openTimeout))


def this(size: Int, openTimeout: Duration, idleTimeout: Duration, repopulateInterval: Duration, def this(size: Int, openTimeout: Duration, idleTimeout: Duration, repopulateInterval: Duration,
defaultUrlOptions: Map[String, String]) = { defaultUrlOptions: Map[String, String]) = {
this(None, size, openTimeout, idleTimeout, repopulateInterval, defaultUrlOptions) this(None, size, openTimeout, idleTimeout, repopulateInterval, defaultUrlOptions, None)
} }


def this(size: Int, openTimeout: Duration, idleTimeout: Duration, def this(size: Int, openTimeout: Duration, idleTimeout: Duration,
Expand All @@ -192,7 +274,7 @@ class ThrottledPoolingDatabaseFactory(
} }


new ThrottledPoolingDatabase(serviceName, dbhosts, dbname, username, password, finalUrlOptions, new ThrottledPoolingDatabase(serviceName, dbhosts, dbname, username, password, finalUrlOptions,
size, openTimeout, idleTimeout, repopulateInterval) size, openTimeout, idleTimeout, repopulateInterval, failFastPolicy)
} }
} }


Expand All @@ -206,11 +288,13 @@ class ThrottledPoolingDatabase(
numConnections: Int, numConnections: Int,
val openTimeout: Duration, val openTimeout: Duration,
idleTimeout: Duration, idleTimeout: Duration,
repopulateInterval: Duration) extends Database { repopulateInterval: Duration,
val failFastPolicy: FailFastPolicy) extends Database {


Class.forName("com.mysql.jdbc.Driver") Class.forName("com.mysql.jdbc.Driver")


private val pool = new ThrottledPool(mkConnection, numConnections, openTimeout, idleTimeout) private[database] val pool = new ThrottledPool(mkConnection, numConnections, openTimeout,
idleTimeout, failFastPolicy)
private val poolingDataSource = new PoolingDataSource(pool) private val poolingDataSource = new PoolingDataSource(pool)
poolingDataSource.setAccessToUnderlyingConnectionAllowed(true) poolingDataSource.setAccessToUnderlyingConnectionAllowed(true)
new PoolWatchdogThread(pool, hosts, repopulateInterval).start() new PoolWatchdogThread(pool, hosts, repopulateInterval).start()
Expand All @@ -226,7 +310,7 @@ class ThrottledPoolingDatabase(
extraUrlOptions: Map[String, String], numConnections: Int, openTimeout: Duration, extraUrlOptions: Map[String, String], numConnections: Int, openTimeout: Duration,
idleTimeout: Duration, repopulateInterval: Duration) = { idleTimeout: Duration, repopulateInterval: Duration) = {
this(None, hosts, name, username, password, extraUrlOptions, numConnections, openTimeout, this(None, hosts, name, username, password, extraUrlOptions, numConnections, openTimeout,
idleTimeout, repopulateInterval) idleTimeout, repopulateInterval, FailFastBasedOnNumConnsPolicy(openTimeout))
} }


def open() = { def open() = {
Expand Down

0 comments on commit f9f4242

Please sign in to comment.