Skip to content

Commit

Permalink
airframe-grpc: Fix thread local storage (#2373)
Browse files Browse the repository at this point in the history
To properly access thread-local storage in gRPC, the code needs to be wrapped inside the Listener body.
  • Loading branch information
xerial committed Aug 18, 2022
1 parent 78fb568 commit 216a309
Show file tree
Hide file tree
Showing 10 changed files with 108 additions and 17 deletions.
Expand Up @@ -28,7 +28,8 @@ object GrpcContext {
*
* @return
*/
def current: Option[GrpcContext] = Option(contextKey.get())
def current: Option[GrpcContext] = Option(contextKey.get())

private[grpc] def currentEncoding = current.map(_.encoding).getOrElse(RPCEncoding.MsgPack)

private[grpc] val KEY_ACCEPT = Metadata.Key.of("accept", Metadata.ASCII_STRING_MARSHALLER)
Expand Down
Expand Up @@ -15,7 +15,7 @@ package wvlet.airframe.http.grpc

import io.grpc.ServerServiceDefinition
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder
import wvlet.airframe.http.grpc.internal.{GrpcResponseHeaderInterceptor, ContextTrackInterceptor, GrpcRequestLogger}
import wvlet.airframe.http.grpc.internal.{GrpcContextTrackInterceptor, GrpcRequestLogger, GrpcResponseHeaderInterceptor}
import wvlet.log.LogSupport

import java.util.concurrent.ExecutorService
Expand Down Expand Up @@ -46,7 +46,7 @@ case class GrpcService(
// Add an interceptor for setting content-type response header
serverBuilder.intercept(GrpcResponseHeaderInterceptor)
// Add an interceptor for remembering GrpcContext. This must happen at the root level
serverBuilder.intercept(ContextTrackInterceptor)
serverBuilder.intercept(GrpcContextTrackInterceptor)

val customServerBuilder = config.serverInitializer(serverBuilder)
val server = new GrpcServer(this, customServerBuilder.build())
Expand Down
Expand Up @@ -14,28 +14,87 @@
package wvlet.airframe.http.grpc.internal

import io.grpc._
import wvlet.airframe.http.{Compat, RPCContext}
import wvlet.airframe.http.grpc.GrpcContext
import wvlet.log.LogSupport

/**
* A server request interceptor to set GrpcContext to the thread-local storage
*/
private[grpc] object ContextTrackInterceptor extends ServerInterceptor with LogSupport {
private[grpc] object GrpcContextTrackInterceptor extends ServerInterceptor with LogSupport {
override def interceptCall[ReqT, RespT](
call: ServerCall[ReqT, RespT],
headers: Metadata,
next: ServerCallHandler[ReqT, RespT]
): ServerCall.Listener[ReqT] = {
// Tell airframe-http about the thread-local RPC context

// Wrap the current context
val rpcContext = GrpcContext(Option(call.getAuthority), call.getAttributes, headers, call.getMethodDescriptor)
wvlet.airframe.http.Compat.attachRPCContext(rpcContext)

// Create a new context that conveys GrpcContext object.
val newContext = Context
.current().withValue(
GrpcContext.contextKey,
rpcContext
)
Contexts.interceptCall(newContext, call, headers, next)

val previous = newContext.attach()
val prevContext = Compat.attachRPCContext(rpcContext)
try {
new WrappedServerCallListener[(RPCContext, Context), ReqT](
onInit = {
(Compat.attachRPCContext(rpcContext), newContext.attach())
},
onDetach = { case (previousRpcContext: RPCContext, ctx: Context) =>
Compat.detachRPCContext(previousRpcContext)
newContext.detach(ctx)
},
next.startCall(call, headers)
)
} finally {
Compat.detachRPCContext(prevContext)
newContext.detach(previous)
}
}
}

/**
*/
private[grpc] class WrappedServerCallListener[A, ReqT](
onInit: => A,
onDetach: A => Unit,
delegate: ServerCall.Listener[ReqT]
) extends ForwardingServerCallListener.SimpleForwardingServerCallListener[ReqT](delegate) {
private def wrap(body: => Unit): Unit = {
val previous = onInit
try {
body
} finally {
onDetach(previous)
}
}

override def onMessage(message: ReqT): Unit = {
wrap {
delegate.onMessage(message)
}
}

override def onHalfClose(): Unit = {
wrap {
delegate.onHalfClose()
}
}

override def onComplete(): Unit = {
wrap {
delegate.onComplete()
}
}

override def onReady(): Unit = {
wrap {
delegate.onReady()
}
}
}
Expand Up @@ -20,6 +20,7 @@ import wvlet.airframe.http.grpc.example.DemoApi.DemoApiClient
import wvlet.airspec.AirSpec

import java.util.concurrent.Executor
import scala.collection.parallel.ParSeq

object GrpcContextTest extends AirSpec {

Expand All @@ -28,12 +29,14 @@ object GrpcContextTest extends AirSpec {
test("thread local context") { (client: DemoApiClient) =>
test("get context") {
val ret = client.getContext
info(ret)
debug(ret)
}

test("get context from RPCContext") {
val ret = client.getRPCContext
ret shouldBe Some(DemoApi.demoClientId)
for (i <- ParSeq(1 to 10)) {
val ret = client.getRPCContext
ret shouldBe Some(DemoApi.demoClientId)
}
}

test("get http request from RPCContext") {
Expand Down
Expand Up @@ -14,11 +14,20 @@
package wvlet.airframe.http.grpc.example

import io.grpc.stub.{AbstractBlockingStub, ClientCallStreamObserver, ClientCalls}
import io.grpc.{CallOptions, Channel, Contexts, Metadata, ServerCall, ServerCallHandler, ServerInterceptor}
import io.grpc.{
CallOptions,
Channel,
Contexts,
ForwardingServerCallListener,
Metadata,
ServerCall,
ServerCallHandler,
ServerInterceptor
}
import wvlet.airframe.Design
import wvlet.airframe.codec.MessageCodecFactory
import wvlet.airframe.http.HttpMessage.Request
import wvlet.airframe.http.grpc.internal.GrpcServiceBuilder
import wvlet.airframe.http.grpc.internal.{GrpcServiceBuilder, WrappedServerCallListener}
import wvlet.airframe.http.grpc._
import wvlet.airframe.http.router.Route
import wvlet.airframe.http.{Http, HttpStatus, RPC, RPCContext, RPCEncoding, RPCStatus, Router}
Expand Down Expand Up @@ -109,9 +118,17 @@ object DemoApi extends LogSupport {
headers: Metadata,
next: ServerCallHandler[ReqT, RespT]
): ServerCall.Listener[ReqT] = {
val ctx = RPCContext.current
ctx.setThreadLocal("client_id", demoClientId)
next.startCall(call, headers)

new WrappedServerCallListener[Unit, ReqT](
onInit = {
val ctx = RPCContext.current
ctx.setThreadLocal("client_id", demoClientId)
},
onDetach = { _ =>
// do nothing
},
next.startCall(call, headers)
)
}
}

Expand Down
Expand Up @@ -58,4 +58,5 @@ private object Compat extends CompatApi {

override def currentRPCContext: RPCContext = ???
override def attachRPCContext(context: RPCContext): RPCContext = ???
override def detachRPCContext(previous: RPCContext): Unit = ???
}
Expand Up @@ -59,4 +59,5 @@ object Compat extends CompatApi {

override def currentRPCContext: RPCContext = LocalRPCContext.current
override def attachRPCContext(context: RPCContext): RPCContext = LocalRPCContext.attach(context)
override def detachRPCContext(previous: RPCContext): Unit = LocalRPCContext.detach(previous)
}
Expand Up @@ -17,7 +17,7 @@ import wvlet.airframe.http.{RPCContext, EmptyRPCContext}

object LocalRPCContext {
private val localContext = new ThreadLocal[RPCContext]()
private val rootContext = new EmptyRPCContext()
private val rootContext = EmptyRPCContext

def current: RPCContext = {
Option(localContext.get()).getOrElse(rootContext)
Expand All @@ -31,4 +31,12 @@ object LocalRPCContext {
localContext.set(newContext)
prev
}
def detach(previousContext: RPCContext): Unit = {
if (previousContext != rootContext) {
localContext.set(previousContext)
} else {
// Avoid preserving the root thread information in the TLS
localContext.set(null)
}
}
}
Expand Up @@ -29,4 +29,5 @@ trait CompatApi {

def currentRPCContext: RPCContext
def attachRPCContext(context: RPCContext): RPCContext
def detachRPCContext(previous: RPCContext): Unit
}
Expand Up @@ -53,7 +53,7 @@ trait RPCContext {
/**
* An empty RPCContext
*/
class EmptyRPCContext extends RPCContext {
object EmptyRPCContext extends RPCContext {
override def setThreadLocal[A](key: String, value: A): Unit = {
// no-op
}
Expand Down

0 comments on commit 216a309

Please sign in to comment.