/
ThriftServerCodec.scala
74 lines (64 loc) · 2.61 KB
/
ThriftServerCodec.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
package com.twitter.finagle.thrift
import org.apache.thrift.TApplicationException
import org.apache.thrift.protocol.{TMessageType, TBinaryProtocol}
import org.jboss.netty.buffer.{ChannelBuffer, ChannelBuffers}
import org.jboss.netty.channel._
import org.jboss.netty.handler.codec.replay.{ReplayingDecoder, VoidEnum}
import java.util.logging.{Logger, Level}
/**
 * Translates ThriftReply messages into their wire representation.
 *
 * Each reply is serialized through a binary Thrift protocol into a freshly
 * allocated dynamic buffer, and that buffer is written downstream in place of
 * the original message.
 */
private[thrift] class ThriftServerEncoder extends SimpleChannelDownstreamHandler {
  // Both flags are true: strict mode for reading and for writing.
  protected val protocolFactory = new TBinaryProtocol.Factory(true, true)

  /**
   * Encodes an outgoing ThriftReply and forwards the encoded bytes downstream.
   *
   * The downstream write is issued with the future of the original event, so
   * whoever requested the write observes the real outcome of the I/O
   * operation. Any message that is not a ThriftReply is rejected: the
   * event's future is failed and the same exception is fired upstream.
   *
   * @param ctx the context of this handler in the pipeline
   * @param e   the write request; its message is expected to be a ThriftReply
   */
  override def writeRequested(ctx: ChannelHandlerContext, e: MessageEvent) =
    e.getMessage match {
      case ThriftReply(response, call) =>
        val buffer = ChannelBuffers.dynamicBuffer()
        val transport = new ChannelBufferToTransport(buffer)
        val protocol = protocolFactory.getProtocol(transport)
        call.writeReply(call.seqid, protocol, response)
        // Propagate the event's own future rather than a pre-succeeded one;
        // otherwise the writer's future would never be completed and write
        // failures would go unnoticed.
        Channels.write(ctx, e.getFuture, buffer, e.getRemoteAddress)
      case other =>
        val cause = new IllegalArgumentException(
          "ThriftServerEncoder can only encode ThriftReply messages, got: " + other)
        // Nothing is written for this event, so complete its future here to
        // avoid leaving the writer waiting forever.
        e.getFuture.setFailure(cause)
        Channels.fireExceptionCaught(ctx, cause)
    }
}
/**
 * Decodes the wire representation of incoming Thrift messages into call
 * objects looked up through ThriftTypes.
 */
private[thrift] class ThriftServerDecoder extends ReplayingDecoder[VoidEnum] {
  private[this] val protocolFactory = new TBinaryProtocol.Factory(true, true)
  private[this] val logger = Logger.getLogger(getClass.getName)

  /**
   * Reads one Thrift message from `buffer`.
   *
   * @param ctx     the handler context (not used by this method)
   * @param channel the channel the bytes arrived on (not used by this method)
   * @param buffer  the buffer to read the message from
   * @return the populated request object for a CALL message; `null` for any
   *         other message type, or when a TApplicationException is raised
   *         while building or reading the request
   */
  def decodeThriftCall(ctx: ChannelHandlerContext, channel: Channel,
                       buffer: ChannelBuffer):Object = {
    val protocol = protocolFactory.getProtocol(new ChannelBufferToTransport(buffer))
    val header = protocol.readMessageBegin()

    if (header.`type` != TMessageType.CALL) {
      // We can't respond with an error because we're in a replaying codec.
      null
    } else {
      try {
        val call = ThriftTypes(header.name).newInstance(header.seqid)
        call.readRequestArgs(protocol)
        call.asInstanceOf[AnyRef]
      } catch {
        // Invalid-message errors and the like are logged at FINE and the
        // message is dropped.
        case e: TApplicationException =>
          logger.log(Level.FINE, e.getMessage, e)
          null
      }
    }
  }

  /**
   * Entry point invoked by the replaying decoder; delegates to
   * `decodeThriftCall` whenever the buffer reports readable bytes.
   */
  override def decode(ctx: ChannelHandlerContext,
                      channel: Channel,
                      buffer: ChannelBuffer,
                      state: VoidEnum) =
    // Thrift incorrectly treats a zero-byte read as an error, so an empty
    // buffer is handled as a no-op instead of being handed to the protocol.
    if (!buffer.readable)
      null
    else
      decodeThriftCall(ctx, channel, buffer)
}