Permalink
Browse files

thrift: support for buffered transports, client only.

  • Loading branch information...
1 parent b8ff969 commit 62d4961f6908557b772e993c0bb6e940ac1509b1 @mariusae mariusae committed Mar 21, 2011
@@ -0,0 +1,31 @@
+package com.twitter.finagle.thrift
+
+import org.jboss.netty.channel.{Channel, ChannelHandlerContext}
+import org.jboss.netty.buffer.ChannelBuffer
+import org.jboss.netty.handler.codec.replay.{ReplayingDecoder, VoidEnum}
+
+import org.apache.thrift.protocol.{TProtocolFactory, TProtocolUtil, TType}
+
+class ThriftBufferCodec(protocolFactory: TProtocolFactory)
+ extends ReplayingDecoder[VoidEnum]
+{
+ override def decode(
+ ctx: ChannelHandlerContext, channel: Channel,
+ buffer: ChannelBuffer, state: VoidEnum
+ ) = {
+ val transport = new ChannelBufferToTransport(buffer)
+ val iprot = protocolFactory.getProtocol(transport)
+
+ val beginIndex = buffer.readerIndex
+ buffer.markReaderIndex()
+
+ iprot.readMessageBegin()
+ TProtocolUtil.skip(iprot, TType.STRUCT)
+ iprot.readMessageEnd()
+
+ val endIndex = buffer.readerIndex
+ buffer.resetReaderIndex()
+
+ buffer.readSlice(endIndex - beginIndex)
+ }
+}
@@ -0,0 +1,29 @@
+package com.twitter.finagle.thrift
+
+import org.jboss.netty.channel.ChannelPipelineFactory
+import org.apache.thrift.protocol.TProtocolFactory
+
+import com.twitter.finagle.Codec
+
+class ThriftClientBufferedCodec(protocolFactory: TProtocolFactory)
+ extends Codec[ThriftClientRequest, Array[Byte]]
+{
+ private[this] val framedCodec = new ThriftClientFramedCodec
+
+ val clientPipelineFactory = {
+ val framedPipelineFactory = framedCodec.clientPipelineFactory
+
+ new ChannelPipelineFactory {
+ def getPipeline() = {
+ val pipeline = framedPipelineFactory.getPipeline
+ pipeline.replace(
+ "thriftFrameCodec", "thriftBufferCodec",
+ new ThriftBufferCodec(protocolFactory))
+ pipeline
+ }
+ }
+ }
+
+ val serverPipelineFactory = clientPipelineFactory
+}
+
@@ -6,20 +6,21 @@ import java.util.concurrent.CyclicBarrier
import org.specs.Specification
-import org.apache.thrift.transport.{TServerSocket, TFramedTransport}
+import org.apache.thrift.transport.{TServerSocket, TFramedTransport, TTransportFactory}
import org.apache.thrift.protocol.TBinaryProtocol
import org.apache.thrift.server.TSimpleServer
import org.apache.thrift.async.AsyncMethodCallback
import com.twitter.test.{B, AnException, SomeStruct}
import com.twitter.util.{RandomSocket, Promise, Return, Throw, Future}
+import com.twitter.finagle.Codec
import com.twitter.finagle.builder.ClientBuilder
object FinagleClientThriftServerSpec extends Specification {
"finagle client vs. synchronous thrift server" should {
var somewayPromise = new Promise[Unit]
- def makeServer(f: (Int, Int) => Int) = {
+ def makeServer(transportFactory: TTransportFactory)(f: (Int, Int) => Int) = {
val processor = new B.Iface {
def multiply(a: Int, b: Int): Int = f(a, b)
def add(a: Int, b: Int): Int = { throw new AnException }
@@ -40,7 +41,7 @@ object FinagleClientThriftServerSpec extends Specification {
val server = new TSimpleServer(
new B.Processor(processor),
serverSocketTransport,
- new TFramedTransport.Factory(),
+ transportFactory,
new TBinaryProtocol.Factory()
)
@@ -60,90 +61,101 @@ object FinagleClientThriftServerSpec extends Specification {
thriftServerAddr
}
- "talk to each other" in {
- // TODO: interleave requests (to test seqids, etc.)
- val thriftServerAddr = makeServer { (a, b) => a + b }
-
- // ** Set up the client & query the server.
- val service = ClientBuilder()
- .hosts(Seq(thriftServerAddr))
- .codec(ThriftClientFramedCodec())
- .build()
-
- val client = new B.ServiceToClient(service, new TBinaryProtocol.Factory())
-
- val future = client.multiply(1, 2)
- future() must be_==(3)
- }
-
- "handle exceptions" in {
- val thriftServerAddr = makeServer { (a, b) => a + b }
-
- // ** Set up the client & query the server.
- val service = ClientBuilder()
- .hosts(Seq(thriftServerAddr))
- .codec(ThriftClientFramedCodec())
- .build()
-
- val client = new B.ServiceToClient(service, new TBinaryProtocol.Factory())
-
- client.add(1, 2)() must throwA[AnException]
- }
-
- "handle void returns" in {
- val thriftServerAddr = makeServer { (a, b) => a + b }
-
- // ** Set up the client & query the server.
- val service = ClientBuilder()
- .hosts(Seq(thriftServerAddr))
- .codec(ThriftClientFramedCodec())
- .build()
-
- val client = new B.ServiceToClient(service, new TBinaryProtocol.Factory())
-
- client.add_one(1, 2)()
- true must beTrue
+ def doit(transportFactory: TTransportFactory, codec: Codec[ThriftClientRequest, Array[Byte]]) {
+ "talk to each other" in {
+ // TODO: interleave requests (to test seqids, etc.)
+
+ val thriftServerAddr = makeServer(transportFactory) { (a, b) => a + b }
+
+ // ** Set up the client & query the server.
+ val service = ClientBuilder()
+ .hosts(Seq(thriftServerAddr))
+ .codec(codec)
+ .build()
+
+ val client = new B.ServiceToClient(service, new TBinaryProtocol.Factory())
+
+ val future = client.multiply(1, 2)
+ future() must be_==(3)
+ }
+
+ "handle exceptions" in {
+ val thriftServerAddr = makeServer(transportFactory) { (a, b) => a + b }
+
+ // ** Set up the client & query the server.
+ val service = ClientBuilder()
+ .hosts(Seq(thriftServerAddr))
+ .codec(codec)
+ .build()
+
+ val client = new B.ServiceToClient(service, new TBinaryProtocol.Factory())
+
+ client.add(1, 2)() must throwA[AnException]
+ }
+
+ "handle void returns" in {
+ val thriftServerAddr = makeServer(transportFactory) { (a, b) => a + b }
+
+ // ** Set up the client & query the server.
+ val service = ClientBuilder()
+ .hosts(Seq(thriftServerAddr))
+ .codec(codec)
+ .build()
+
+ val client = new B.ServiceToClient(service, new TBinaryProtocol.Factory())
+
+ client.add_one(1, 2)()
+ true must beTrue
+ }
+
+ // race condition..
+ "handle one-way calls" in {
+ val thriftServerAddr = makeServer(transportFactory) { (a, b) => a + b }
+
+ // ** Set up the client & query the server.
+ val service = ClientBuilder()
+ .hosts(Seq(thriftServerAddr))
+ .codec(codec)
+ .build()
+
+ val client = new B.ServiceToClient(service, new TBinaryProtocol.Factory())
+
+ somewayPromise.isDefined must beFalse
+ client.someway()() must beNull // returns
+ somewayPromise() must be_==(())
+ }
+
+ "talk to multiple servers" in {
+ val NumParties = 10
+ val barrier = new CyclicBarrier(NumParties)
+
+ val addrs = 0 until NumParties map { _ =>
+ makeServer(transportFactory) { (a, b) => barrier.await(); a + b }
+ }
+
+ // ** Set up the client & query the server.
+ val service = ClientBuilder()
+ .hosts(addrs)
+ .codec(codec)
+ .build()
+
+ val client = new B.ServiceToClient(service, new TBinaryProtocol.Factory())
+
+ {
+ val futures = 0 until NumParties map { _ => client.multiply(1, 2) }
+ val resolved = futures map(_())
+ resolved foreach { r => r must be_==(3) }
+ }
+ }
}
- // race condition..
- "handle one-way calls" in {
- val thriftServerAddr = makeServer { (a, b) => a + b }
-
- // ** Set up the client & query the server.
- val service = ClientBuilder()
- .hosts(Seq(thriftServerAddr))
- .codec(ThriftClientFramedCodec())
- .build()
-
- val client = new B.ServiceToClient(service, new TBinaryProtocol.Factory())
-
- somewayPromise.isDefined must beFalse
- client.someway()() must beNull // returns
- somewayPromise() must be_==(())
+ "framed transport" in {
+ doit(new TFramedTransport.Factory(), ThriftClientFramedCodec())
}
- "talk to multiple servers" in {
- val NumParties = 10
- val barrier = new CyclicBarrier(NumParties)
-
- val addrs = 0 until NumParties map { _ =>
- makeServer { (a, b) => barrier.await(); a + b }
- }
-
- // ** Set up the client & query the server.
- val service = ClientBuilder()
- .hosts(addrs)
- .codec(ThriftClientFramedCodec())
- .build()
-
- val client = new B.ServiceToClient(service, new TBinaryProtocol.Factory())
-
- {
- val futures = 0 until NumParties map { _ => client.multiply(1, 2) }
- val resolved = futures map(_())
- resolved foreach { r => r must be_==(3) }
- }
+ "buffered transport" in {
+ doit(new TTransportFactory, new ThriftClientBufferedCodec(new TBinaryProtocol.Factory))
}
}
}

0 comments on commit 62d4961

Please sign in to comment.