Skip to content

Commit

Permalink
thrift: support for buffered transports, client only.
Browse files Browse the repository at this point in the history
  • Loading branch information
mariusae committed Mar 21, 2011
1 parent b8ff969 commit 62d4961
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 82 deletions.
@@ -0,0 +1,31 @@
package com.twitter.finagle.thrift

import org.jboss.netty.channel.{Channel, ChannelHandlerContext}
import org.jboss.netty.buffer.ChannelBuffer
import org.jboss.netty.handler.codec.replay.{ReplayingDecoder, VoidEnum}

import org.apache.thrift.protocol.{TProtocolFactory, TProtocolUtil, TType}

class ThriftBufferCodec(protocolFactory: TProtocolFactory)
extends ReplayingDecoder[VoidEnum]
{
override def decode(
ctx: ChannelHandlerContext, channel: Channel,
buffer: ChannelBuffer, state: VoidEnum
) = {
val transport = new ChannelBufferToTransport(buffer)
val iprot = protocolFactory.getProtocol(transport)

val beginIndex = buffer.readerIndex
buffer.markReaderIndex()

iprot.readMessageBegin()
TProtocolUtil.skip(iprot, TType.STRUCT)
iprot.readMessageEnd()

val endIndex = buffer.readerIndex
buffer.resetReaderIndex()

buffer.readSlice(endIndex - beginIndex)
}
}
@@ -0,0 +1,29 @@
package com.twitter.finagle.thrift

import org.jboss.netty.channel.ChannelPipelineFactory
import org.apache.thrift.protocol.TProtocolFactory

import com.twitter.finagle.Codec

class ThriftClientBufferedCodec(protocolFactory: TProtocolFactory)
extends Codec[ThriftClientRequest, Array[Byte]]
{
private[this] val framedCodec = new ThriftClientFramedCodec

val clientPipelineFactory = {
val framedPipelineFactory = framedCodec.clientPipelineFactory

new ChannelPipelineFactory {
def getPipeline() = {
val pipeline = framedPipelineFactory.getPipeline
pipeline.replace(
"thriftFrameCodec", "thriftBufferCodec",
new ThriftBufferCodec(protocolFactory))
pipeline
}
}
}

val serverPipelineFactory = clientPipelineFactory
}

Expand Up @@ -6,20 +6,21 @@ import java.util.concurrent.CyclicBarrier

import org.specs.Specification

import org.apache.thrift.transport.{TServerSocket, TFramedTransport}
import org.apache.thrift.transport.{TServerSocket, TFramedTransport, TTransportFactory}
import org.apache.thrift.protocol.TBinaryProtocol
import org.apache.thrift.server.TSimpleServer
import org.apache.thrift.async.AsyncMethodCallback

import com.twitter.test.{B, AnException, SomeStruct}
import com.twitter.util.{RandomSocket, Promise, Return, Throw, Future}

import com.twitter.finagle.Codec
import com.twitter.finagle.builder.ClientBuilder

object FinagleClientThriftServerSpec extends Specification {
"finagle client vs. synchronous thrift server" should {
var somewayPromise = new Promise[Unit]
def makeServer(f: (Int, Int) => Int) = {
def makeServer(transportFactory: TTransportFactory)(f: (Int, Int) => Int) = {
val processor = new B.Iface {
def multiply(a: Int, b: Int): Int = f(a, b)
def add(a: Int, b: Int): Int = { throw new AnException }
Expand All @@ -40,7 +41,7 @@ object FinagleClientThriftServerSpec extends Specification {
val server = new TSimpleServer(
new B.Processor(processor),
serverSocketTransport,
new TFramedTransport.Factory(),
transportFactory,
new TBinaryProtocol.Factory()
)

Expand All @@ -60,90 +61,101 @@ object FinagleClientThriftServerSpec extends Specification {
thriftServerAddr
}

"talk to each other" in {
// TODO: interleave requests (to test seqids, etc.)

val thriftServerAddr = makeServer { (a, b) => a + b }

// ** Set up the client & query the server.
val service = ClientBuilder()
.hosts(Seq(thriftServerAddr))
.codec(ThriftClientFramedCodec())
.build()

val client = new B.ServiceToClient(service, new TBinaryProtocol.Factory())

val future = client.multiply(1, 2)
future() must be_==(3)
}

"handle exceptions" in {
val thriftServerAddr = makeServer { (a, b) => a + b }

// ** Set up the client & query the server.
val service = ClientBuilder()
.hosts(Seq(thriftServerAddr))
.codec(ThriftClientFramedCodec())
.build()

val client = new B.ServiceToClient(service, new TBinaryProtocol.Factory())

client.add(1, 2)() must throwA[AnException]
}

"handle void returns" in {
val thriftServerAddr = makeServer { (a, b) => a + b }

// ** Set up the client & query the server.
val service = ClientBuilder()
.hosts(Seq(thriftServerAddr))
.codec(ThriftClientFramedCodec())
.build()

val client = new B.ServiceToClient(service, new TBinaryProtocol.Factory())

client.add_one(1, 2)()
true must beTrue
def doit(transportFactory: TTransportFactory, codec: Codec[ThriftClientRequest, Array[Byte]]) {
"talk to each other" in {
// TODO: interleave requests (to test seqids, etc.)

val thriftServerAddr = makeServer(transportFactory) { (a, b) => a + b }

// ** Set up the client & query the server.
val service = ClientBuilder()
.hosts(Seq(thriftServerAddr))
.codec(codec)
.build()

val client = new B.ServiceToClient(service, new TBinaryProtocol.Factory())

val future = client.multiply(1, 2)
future() must be_==(3)
}

"handle exceptions" in {
val thriftServerAddr = makeServer(transportFactory) { (a, b) => a + b }

// ** Set up the client & query the server.
val service = ClientBuilder()
.hosts(Seq(thriftServerAddr))
.codec(codec)
.build()

val client = new B.ServiceToClient(service, new TBinaryProtocol.Factory())

client.add(1, 2)() must throwA[AnException]
}

"handle void returns" in {
val thriftServerAddr = makeServer(transportFactory) { (a, b) => a + b }

// ** Set up the client & query the server.
val service = ClientBuilder()
.hosts(Seq(thriftServerAddr))
.codec(codec)
.build()

val client = new B.ServiceToClient(service, new TBinaryProtocol.Factory())

client.add_one(1, 2)()
true must beTrue
}

// race condition..
"handle one-way calls" in {
val thriftServerAddr = makeServer(transportFactory) { (a, b) => a + b }

// ** Set up the client & query the server.
val service = ClientBuilder()
.hosts(Seq(thriftServerAddr))
.codec(codec)
.build()

val client = new B.ServiceToClient(service, new TBinaryProtocol.Factory())

somewayPromise.isDefined must beFalse
client.someway()() must beNull // returns
somewayPromise() must be_==(())
}

"talk to multiple servers" in {
val NumParties = 10
val barrier = new CyclicBarrier(NumParties)

val addrs = 0 until NumParties map { _ =>
makeServer(transportFactory) { (a, b) => barrier.await(); a + b }
}

// ** Set up the client & query the server.
val service = ClientBuilder()
.hosts(addrs)
.codec(codec)
.build()

val client = new B.ServiceToClient(service, new TBinaryProtocol.Factory())

{
val futures = 0 until NumParties map { _ => client.multiply(1, 2) }
val resolved = futures map(_())
resolved foreach { r => r must be_==(3) }
}
}
}

// race condition..
"handle one-way calls" in {
val thriftServerAddr = makeServer { (a, b) => a + b }

// ** Set up the client & query the server.
val service = ClientBuilder()
.hosts(Seq(thriftServerAddr))
.codec(ThriftClientFramedCodec())
.build()

val client = new B.ServiceToClient(service, new TBinaryProtocol.Factory())

somewayPromise.isDefined must beFalse
client.someway()() must beNull // returns
somewayPromise() must be_==(())
"framed transport" in {
doit(new TFramedTransport.Factory(), ThriftClientFramedCodec())
}

"talk to multiple servers" in {
val NumParties = 10
val barrier = new CyclicBarrier(NumParties)

val addrs = 0 until NumParties map { _ =>
makeServer { (a, b) => barrier.await(); a + b }
}

// ** Set up the client & query the server.
val service = ClientBuilder()
.hosts(addrs)
.codec(ThriftClientFramedCodec())
.build()

val client = new B.ServiceToClient(service, new TBinaryProtocol.Factory())

{
val futures = 0 until NumParties map { _ => client.multiply(1, 2) }
val resolved = futures map(_())
resolved foreach { r => r must be_==(3) }
}
"buffered transport" in {
doit(new TTransportFactory, new ThriftClientBufferedCodec(new TBinaryProtocol.Factory))
}
}
}

0 comments on commit 62d4961

Please sign in to comment.