Skip to content
Browse files

[split] scrooge: use a TReusableMemoryTransport in finagle services

https://jira.twitter.biz/browse/PASSBIRD-658

Add a variant of Thrift's TMemoryBuffer ??? TReusableMemoryTransport.
It is a reimplementation of TMemoryBuffer with an added method
`reset()` that allows us to reuse this transport.

RB_ID=263349
  • Loading branch information...
1 parent 249aa5d commit c2bb897b5201f4c694a6245f5bfe626c52b1b684 @kevinoliver kevinoliver committed with CI
View
71 scrooge-core/src/main/scala/com/twitter/scrooge/TReusableMemoryTransport.scala
@@ -0,0 +1,71 @@
+package com.twitter.scrooge
+
+import org.apache.thrift.transport.TTransport
+import org.apache.thrift.TByteArrayOutputStream
+
+object TReusableMemoryTransport {
+
+ def apply(initialSize: Int = 512): TReusableMemoryTransport = {
+ new TReusableMemoryTransport(new TByteArrayOutputStream(initialSize))
+ }
+
+}
+
+/**
+ * A version of TMemoryTransport that allows for reuse in order to minimize
+ * object allocations.
+ */
+class TReusableMemoryTransport(
+ baos: TByteArrayOutputStream)
+ extends TTransport
+{
+
+ private[this] var readPos = 0
+
+ /**
+ * Resets both reads and writes.
+ */
+ def reset() {
+ baos.reset()
+ readPos = 0
+ }
+
+ def numWrittenBytes: Int = baos.len()
+
+ // Here for drop-in api compatibility with TMemoryBuffer
+ def length(): Int = numWrittenBytes
+
+ // Here for drop-in api compatibility with TMemoryBuffer
+ def getArray(): Array[Byte] = baos.get()
+
+ /**
+ * Total bytes currently allowed in the struct.
+ *
+ * Note: more writes beyond this length will cause the struct to grow.
+ */
+ def currentCapacity: Int = baos.get().length
+
+ override def isOpen: Boolean = true
+
+ override def close() { }
+
+ override def open() { }
+
+ override def write(from: Array[Byte], off: Int, len: Int) {
+ baos.write(from, off, len)
+ }
+
+ override def read(into: Array[Byte], off: Int, len: Int): Int = {
+ val bytesToRead = if (len > baos.len() - readPos) {
+ baos.len() - readPos
+ } else {
+ len
+ }
+ if (bytesToRead > 0) {
+ System.arraycopy(baos.get(), readPos, into, off, bytesToRead)
+ readPos += bytesToRead
+ }
+ bytesToRead
+ }
+
+}
View
56 scrooge-generator/src/main/resources/scalagen/finagleService.scala
@@ -1,13 +1,13 @@
package {{package}}
import com.twitter.finagle.{Service => FinagleService}
-import com.twitter.scrooge.ThriftStruct
+import com.twitter.scrooge.{ThriftStruct, TReusableMemoryTransport}
import com.twitter.util.Future
import java.nio.ByteBuffer
import java.util.Arrays
import org.apache.thrift.protocol._
import org.apache.thrift.TApplicationException
-import org.apache.thrift.transport.{TMemoryBuffer, TMemoryInputTransport}
+import org.apache.thrift.transport.TMemoryInputTransport
import scala.collection.mutable.{
ArrayBuffer => mutable$ArrayBuffer, HashMap => mutable$HashMap}
import scala.collection.{Map, Set}
@@ -21,6 +21,22 @@ class {{ServiceName}}$FinagleService(
import {{ServiceName}}._
{{^hasParent}}
+ private[this] val tlReusableBuffer = new ThreadLocal[TReusableMemoryTransport] {
+ override def initialValue() = TReusableMemoryTransport(512)
+ }
+
+ private[this] def reusableBuffer: TReusableMemoryTransport = {
+ val buf = tlReusableBuffer.get()
+ buf.reset()
+ buf
+ }
+
+ private[this] def resetBuffer(trans: TReusableMemoryTransport, maxCapacity: Int = 4096) {
+ if (trans.currentCapacity > maxCapacity) {
+ tlReusableBuffer.remove()
+ }
+ }
+
protected val functionMap = new mutable$HashMap[String, (TProtocol, Int) => Future[Array[Byte]]]()
protected def addFunction(name: String, f: (TProtocol, Int) => Future[Array[Byte]]) {
@@ -30,14 +46,18 @@ class {{ServiceName}}$FinagleService(
protected def exception(name: String, seqid: Int, code: Int, message: String): Future[Array[Byte]] = {
try {
val x = new TApplicationException(code, message)
- val memoryBuffer = new TMemoryBuffer(512)
- val oprot = protocolFactory.getProtocol(memoryBuffer)
-
- oprot.writeMessageBegin(new TMessage(name, TMessageType.EXCEPTION, seqid))
- x.write(oprot)
- oprot.writeMessageEnd()
- oprot.getTransport().flush()
- Future.value(Arrays.copyOfRange(memoryBuffer.getArray(), 0, memoryBuffer.length()))
+ val memoryBuffer = reusableBuffer
+ try {
+ val oprot = protocolFactory.getProtocol(memoryBuffer)
+
+ oprot.writeMessageBegin(new TMessage(name, TMessageType.EXCEPTION, seqid))
+ x.write(oprot)
+ oprot.writeMessageEnd()
+ oprot.getTransport().flush()
+ Future.value(Arrays.copyOfRange(memoryBuffer.getArray(), 0, memoryBuffer.length()))
+ } finally {
+ resetBuffer(memoryBuffer)
+ }
} catch {
case e: Exception => Future.exception(e)
}
@@ -45,14 +65,18 @@ class {{ServiceName}}$FinagleService(
protected def reply(name: String, seqid: Int, result: ThriftStruct): Future[Array[Byte]] = {
try {
- val memoryBuffer = new TMemoryBuffer(512)
- val oprot = protocolFactory.getProtocol(memoryBuffer)
+ val memoryBuffer = reusableBuffer
+ try {
+ val oprot = protocolFactory.getProtocol(memoryBuffer)
- oprot.writeMessageBegin(new TMessage(name, TMessageType.REPLY, seqid))
- result.write(oprot)
- oprot.writeMessageEnd()
+ oprot.writeMessageBegin(new TMessage(name, TMessageType.REPLY, seqid))
+ result.write(oprot)
+ oprot.writeMessageEnd()
- Future.value(Arrays.copyOfRange(memoryBuffer.getArray(), 0, memoryBuffer.length()))
+ Future.value(Arrays.copyOfRange(memoryBuffer.getArray(), 0, memoryBuffer.length()))
+ } finally {
+ resetBuffer(memoryBuffer)
+ }
} catch {
case e: Exception => Future.exception(e)
}
View
33 scrooge-generator/src/test/scala/com/twitter/scrooge/TReusableMemoryTransportSpec.scala
@@ -0,0 +1,33 @@
+package com.twitter.scrooge
+
+import com.twitter.scrooge.testutil.Spec
+
+class TReusableMemoryTransportSpec extends Spec {
+
+ "is reusable" in {
+ val cap = 10
+ val trans = TReusableMemoryTransport(cap)
+ trans.currentCapacity must be(cap)
+ var bytesRead = trans.read(new Array[Byte](1), 0, 999)
+ bytesRead must be(0)
+
+ val stringInBytes = "abcde".getBytes("UTF-8")
+ trans.write(stringInBytes)
+ trans.numWrittenBytes must be(5)
+
+ val read = new Array[Byte](100)
+ bytesRead = trans.read(read, 0, 999)
+ bytesRead must be(5)
+ read.take(bytesRead).mkString must be(stringInBytes.mkString)
+
+ trans.reset()
+
+ trans.numWrittenBytes must be(0)
+ bytesRead = trans.read(new Array[Byte](1), 0, 999)
+ bytesRead must be(0)
+
+ trans.write(stringInBytes)
+ trans.numWrittenBytes must be(5)
+ }
+
+}

0 comments on commit c2bb897

Please sign in to comment.
Something went wrong with that request. Please try again.