Permalink
Browse files

New ServerBuilder style

project.version=1.2.5
  • Loading branch information...
1 parent 71a2239 commit 62f760f2eee4a782b61fa647b5b7af1050ca6f1a Nick Kallen committed Mar 24, 2011
@@ -3,19 +3,17 @@ package com.twitter.finagle.builder
import scala.collection.mutable.HashSet
import scala.collection.JavaConversions._
-import java.util.concurrent.{Executors, LinkedBlockingQueue}
+import java.util.concurrent.Executors
import java.util.logging.Logger
import java.net.SocketAddress
-import javax.net.ssl.SSLContext
import org.jboss.netty.bootstrap.ServerBootstrap
import org.jboss.netty.channel._
import org.jboss.netty.channel.socket.nio._
-import org.jboss.netty.handler.timeout.IdleStateHandler
import org.jboss.netty.handler.ssl._
import org.jboss.netty.handler.timeout.ReadTimeoutHandler
-import com.twitter.util.{Time, Duration}
+import com.twitter.util.Duration
import com.twitter.conversions.time._
import com.twitter.finagle._
@@ -26,7 +24,7 @@ import com.twitter.finagle.tracing.{TraceReceiver, TracingFilter, NullTraceRecei
import com.twitter.finagle.util.Conversions._
import com.twitter.finagle.util._
import com.twitter.finagle.util.Timer._
-import com.twitter.util.{Future, Promise, Return, Throw}
+import com.twitter.util.Future
import channel.{ChannelClosingHandler, ServiceToChannelHandler, ChannelSemaphoreHandler}
import service.{ExpiringService, TimeoutFilter, StatsFilter}
@@ -42,7 +40,7 @@ trait Server {
}
object ServerBuilder {
- def apply() = new ServerBuilder[Any, Any]()
+ def apply() = new ServerBuilder[Any, Any]
def get() = apply()
val defaultChannelFactory =
@@ -54,48 +52,49 @@ object ServerBuilder {
}
/**
- * A handy Builder for constructing Servers (i.e., binding Services to a port).
+ * A configuration object that represents what shall be built.
*/
-case class ServerBuilder[Req, Rep](
- private val _codec: Option[Codec[Req, Rep]],
- private val _statsReceiver: Option[StatsReceiver],
- private val _name: Option[String],
- private val _sendBufferSize: Option[Int],
- private val _recvBufferSize: Option[Int],
- private val _bindTo: Option[SocketAddress],
- private val _logger: Option[Logger],
- private val _tls: Option[(String, String)],
- private val _startTls: Boolean,
- private val _channelFactory: Option[ReferenceCountedChannelFactory],
- private val _maxConcurrentRequests: Option[Int],
- private val _hostConnectionMaxIdleTime: Option[Duration],
- private val _requestTimeout: Option[Duration],
- private val _readTimeout: Option[Duration],
- private val _writeCompletionTimeout: Option[Duration],
- private val _traceReceiver: TraceReceiver)
+final case class ServerConfig[Req, Rep](
+ private val _codec: Option[Codec[Req, Rep]] = None,
+ private val _statsReceiver: Option[StatsReceiver] = None,
+ private val _name: Option[String] = None,
+ private val _sendBufferSize: Option[Int] = None,
+ private val _recvBufferSize: Option[Int] = None,
+ private val _bindTo: Option[SocketAddress] = None,
+ private val _logger: Option[Logger] = None,
+ private val _tls: Option[(String, String)] = None,
+ private val _startTls: Boolean = false,
+ private val _channelFactory: ReferenceCountedChannelFactory = ServerBuilder.defaultChannelFactory,
+ private val _maxConcurrentRequests: Option[Int] = None,
+ private val _hostConnectionMaxIdleTime: Option[Duration] = None,
+ private val _requestTimeout: Option[Duration] = None,
+ private val _readTimeout: Option[Duration] = None,
+ private val _writeCompletionTimeout: Option[Duration] = None,
+ private val _traceReceiver: TraceReceiver = new NullTraceReceiver)
{
- import ServerBuilder._
-
- def this() = this(
- None, // codec
- None, // statsReceiver
- None, // name
- None, // sendBufferSize
- None, // recvBufferSize
- None, // bindTo
- None, // logger
- None, // tls
- false, // startTls
- None, // channelFactory
- None, // maxConcurrentRequests
- None, // hostConnectionMaxIdleTime
- None, // requestTimeout
- None, // readTimeout
- None, // writeCompletionTimeout
- new NullTraceReceiver // traceReceiver
- )
-
- private[this] def options = Seq(
+ /**
+ * The Scala compiler errors if the case class members don't have underscores.
+ * Nevertheless, we want a friendly public API so we create delegators without
+ * underscores.
+ */
+ val codec = _codec
+ val statsReceiver = _statsReceiver
+ val name = _name
+ val sendBufferSize = _sendBufferSize
+ val recvBufferSize = _recvBufferSize
+ val bindTo = _bindTo
+ val logger = _logger
+ val tls = _tls
+ val startTls = _startTls
+ val channelFactory = _channelFactory
+ val maxConcurrentRequests = _maxConcurrentRequests
+ val hostConnectionMaxIdleTime = _hostConnectionMaxIdleTime
+ val requestTimeout = _requestTimeout
+ val readTimeout = _readTimeout
+ val writeCompletionTimeout = _writeCompletionTimeout
+ val traceReceiver = _traceReceiver
+
+ def toMap = Map(
"codec" -> _codec,
"statsReceiver" -> _statsReceiver,
"name" -> _name,
@@ -105,7 +104,7 @@ case class ServerBuilder[Req, Rep](
"logger" -> _logger,
"tls" -> _tls,
"startTls" -> Some(_startTls),
- "channelFactory" -> _channelFactory,
+ "channelFactory" -> Some(_channelFactory),
"maxConcurrentRequests" -> _maxConcurrentRequests,
"hostConnectionMaxIdleTime" -> _hostConnectionMaxIdleTime,
"requestTimeout" -> _requestTimeout,
@@ -114,59 +113,88 @@ case class ServerBuilder[Req, Rep](
"traceReceiver" -> Some(_traceReceiver)
)
- override def toString() = {
- "ServerBuilder(%s)".format(
- options flatMap {
- case (k, Some(v)) => Some("%s=%s".format(k, v))
- case _ => None
+ override def toString = {
+ "ServerConfig(%s)".format(
+ toMap flatMap {
+ case (k, Some(v)) =>
+ Some("%s=%s".format(k, v))
+ case _ =>
+ None
} mkString(", "))
}
+ def assertValid() {
+ _codec.getOrElse {
+ throw new IncompleteSpecification("No codec was specified")
+ }
+ _bindTo.getOrElse {
+ throw new IncompleteSpecification("No port was specified")
+ }
+ }
+}
+
+/**
+ * A handy Builder for constructing Servers (i.e., binding Services to a port).
+ * This class is subclassable. Override copy() and build() to do your own
+ * dirty work.
+ */
+class ServerBuilder[Req, Rep](val config: ServerConfig[Req, Rep]) {
+ import ServerBuilder._
+
+ def this() = this(new ServerConfig)
+
+ override def toString() = "ServerBuilder(%s)".format(config.toString)
+
+ protected def copy[Req1, Rep1](config: ServerConfig[Req1, Rep1]) =
+ new ServerBuilder(config)
+
def codec[Req1, Rep1](codec: Codec[Req1, Rep1]) =
- copy(_codec = Some(codec))
+ copy(config.copy(_codec = Some(codec)))
def reportTo(receiver: StatsReceiver) =
- copy(_statsReceiver = Some(receiver))
+ copy(config.copy(_statsReceiver = Some(receiver)))
- def name(value: String) = copy(_name = Some(value))
+ def name(value: String) =
+ copy(config.copy(_name = Some(value)))
- def sendBufferSize(value: Int) = copy(_sendBufferSize = Some(value))
- def recvBufferSize(value: Int) = copy(_recvBufferSize = Some(value))
+ def sendBufferSize(value: Int) =
+ copy(config.copy(_sendBufferSize = Some(value)))
+
+ def recvBufferSize(value: Int) =
+ copy(config.copy(_recvBufferSize = Some(value)))
def bindTo(address: SocketAddress) =
- copy(_bindTo = Some(address))
+ copy(config.copy(_bindTo = Some(address)))
def channelFactory(cf: ReferenceCountedChannelFactory) =
- copy(_channelFactory = Some(cf))
+ copy(config.copy(_channelFactory = cf))
- def logger(logger: Logger) = copy(_logger = Some(logger))
+ def logger(logger: Logger) =
+ copy(config.copy(_logger = Some(logger)))
def tls(certificatePath: String, keyPath: String) =
- copy(_tls = Some((certificatePath, keyPath)))
+ copy(config.copy(_tls = Some((certificatePath, keyPath))))
def startTls(value: Boolean) =
- copy(_startTls = true)
+ copy(config.copy(_startTls = true))
def maxConcurrentRequests(max: Int) =
- copy(_maxConcurrentRequests = Some(max))
+ copy(config.copy(_maxConcurrentRequests = Some(max)))
def hostConnectionMaxIdleTime(howlong: Duration) =
- copy(_hostConnectionMaxIdleTime = Some(howlong))
+ copy(config.copy(_hostConnectionMaxIdleTime = Some(howlong)))
def requestTimeout(howlong: Duration) =
- copy(_requestTimeout = Some(howlong))
+ copy(config.copy(_requestTimeout = Some(howlong)))
def readTimeout(howlong: Duration) =
- copy(_readTimeout = Some(howlong))
+ copy(config.copy(_readTimeout = Some(howlong)))
def writeCompletionTimeout(howlong: Duration) =
- copy(_writeCompletionTimeout = Some(howlong))
+ copy(config.copy(_writeCompletionTimeout = Some(howlong)))
def traceReceiver(receiver: TraceReceiver) =
- copy(_traceReceiver = receiver)
-
- private[this] def scopedStatsReceiver =
- _statsReceiver map { sr => _name map (sr.scope(_)) getOrElse sr }
+ copy(config.copy(_traceReceiver = receiver))
/**
* Construct the Server, given the provided Service.
@@ -179,23 +207,26 @@ case class ServerBuilder[Req, Rep](
* or supports transactions).
*/
def build(serviceFactory: () => Service[Req, Rep]): Server = {
- val codec = _codec.getOrElse {
- throw new IncompleteSpecification("No codec was specified")
- }
+ config.assertValid()
- val cf = _channelFactory getOrElse defaultChannelFactory
+ val scopedStatsReceiver =
+ config.statsReceiver map { sr => config.name map (sr.scope(_)) getOrElse sr }
+
+ val codec = config.codec.get
+
+ val cf = config.channelFactory
cf.acquire()
val bs = new ServerBootstrap(new ChannelFactoryToServerChannelFactory(cf))
bs.setOption("tcpNoDelay", true)
// bs.setOption("soLinger", 0) // XXX: (TODO)
bs.setOption("reuseAddress", true)
- _sendBufferSize foreach { s => bs.setOption("sendBufferSize", s) }
- _recvBufferSize foreach { s => bs.setOption("receiveBufferSize", s) }
+ config.sendBufferSize foreach { s => bs.setOption("sendBufferSize", s) }
+ config.recvBufferSize foreach { s => bs.setOption("receiveBufferSize", s) }
// TODO: we need something akin to a max queue depth.
val queueingChannelHandlerAndGauges =
- _maxConcurrentRequests map { maxConcurrentRequests =>
+ config.maxConcurrentRequests map { maxConcurrentRequests =>
val semaphore = new AsyncSemaphore(maxConcurrentRequests)
val gauges = scopedStatsReceiver.toList flatMap { sr =>
sr.addGauge("request_concurrency") {
@@ -208,10 +239,8 @@ case class ServerBuilder[Req, Rep](
(new ChannelSemaphoreHandler(semaphore), gauges)
}
- val queueingChannelHandler =
- queueingChannelHandlerAndGauges map { case (q, _) => q }
- val gauges =
- queueingChannelHandlerAndGauges.toList flatMap { case (_, g) => g }
+ val queueingChannelHandler = queueingChannelHandlerAndGauges map { case (q, _) => q }
+ val gauges = queueingChannelHandlerAndGauges.toList flatMap { case (_, g) => g }
trait ChannelHandle {
def drain(): Future[Unit]
@@ -229,9 +258,9 @@ case class ServerBuilder[Req, Rep](
def getPipeline = {
val pipeline = codec.serverPipelineFactory.getPipeline
- _logger foreach { logger =>
+ config.logger foreach { logger =>
pipeline.addFirst(
- "channelLogger", ChannelSnooper(_name getOrElse "server")(logger.info))
+ "channelLogger", ChannelSnooper(config.name getOrElse "server")(logger.info))
}
channelStatsHandler foreach { handler =>
@@ -244,32 +273,30 @@ case class ServerBuilder[Req, Rep](
// Note that the timeout is *after* request decoding. This
// prevents death from clients trying to DoS by slowly
// trickling in bytes to our (accumulating) codec.
- _readTimeout foreach { howlong =>
+ config.readTimeout foreach { howlong =>
val (timeoutValue, timeoutUnit) = howlong.inTimeUnit
pipeline.addLast(
"readTimeout",
new ReadTimeoutHandler(Timer.defaultNettyTimer, timeoutValue, timeoutUnit))
}
- _writeCompletionTimeout foreach { howlong =>
+ config.writeCompletionTimeout foreach { howlong =>
pipeline.addLast(
"writeCompletionTimeout",
new WriteCompletionTimeoutHandler(Timer.default, howlong))
}
// SSL comes first so that ChannelSnooper gets plaintext
- _tls foreach { case (certificatePath, keyPath) =>
+ config.tls foreach { case (certificatePath, keyPath) =>
val sslEngine = Ssl.server(certificatePath, keyPath).createSSLEngine()
sslEngine.setUseClientMode(false)
sslEngine.setEnableSessionCreation(true)
- pipeline.addFirst("ssl", new SslHandler(sslEngine, _startTls))
+ pipeline.addFirst("ssl", new SslHandler(sslEngine, config.startTls))
}
// Serialization keeps the codecs honest.
- pipeline.addLast(
- "requestSerializing",
- new ChannelSemaphoreHandler(new AsyncSemaphore(1)))
+ pipeline.addLast("requestSerializing", new ChannelSemaphoreHandler(new AsyncSemaphore(1)))
// Add this after the serialization to get an accurate request
// count.
@@ -300,20 +327,20 @@ case class ServerBuilder[Req, Rep](
val closingHandler = new ChannelClosingHandler
pipeline.addLast("closingHandler", closingHandler)
- _hostConnectionMaxIdleTime foreach { duration =>
+ config.hostConnectionMaxIdleTime foreach { duration =>
service = new ExpiringService(service, duration) {
override def didExpire() { closingHandler.close() }
}
}
- _requestTimeout foreach { duration =>
+ config.requestTimeout foreach { duration =>
service = (new TimeoutFilter(duration)) andThen service
}
// This has to go last (ie. first in the stack) so that
// protocol-specific trace support can override our generic
// one here.
- service = (new TracingFilter(_traceReceiver)) andThen service
+ service = (new TracingFilter(config.traceReceiver)) andThen service
// Register the channel so we can wait for them for a
// drain. We close the socket but wait for all handlers to
@@ -343,7 +370,7 @@ case class ServerBuilder[Req, Rep](
}
})
- val serverChannel = bs.bind(_bindTo.get)
+ val serverChannel = bs.bind(config.bindTo.get)
Timer.default.acquire()
new Server {
def close(timeout: Duration = Duration.MaxValue) = {
@@ -385,13 +412,7 @@ case class ServerBuilder[Req, Rep](
Timer.default.stop()
}
- override def toString = {
- "Server(%s)".format(
- options flatMap {
- case (k, Some(v)) => Some("%s=%s".format(k, v))
- case _ => None
- } mkString(", "))
- }
+ override def toString = "Server(%s)".format(config.toString)
}
}
}
Oops, something went wrong.

0 comments on commit 62f760f

Please sign in to comment.