Skip to content

Commit

Permalink
airframe-grpc: Propagate RPCException to clients (#2139)
Browse files Browse the repository at this point in the history
* Add RPCStatus.ofCodeName
* Add RPCStatus.toException
* propargate stack trace to RPC clients
* Check stacktrace propagation
* Change the key name
* Test suppressed stacktrace
* Fix status code list
  • Loading branch information
xerial committed Apr 27, 2022
1 parent a4b463e commit 7f21193
Show file tree
Hide file tree
Showing 9 changed files with 347 additions and 52 deletions.
Expand Up @@ -24,8 +24,9 @@ class StringUnapplyCodec[A](codec: Surface) extends MessageCodec[A] with LogSupp
override def pack(p: Packer, v: A): Unit = {
p.packString(v.toString)
}

override def unpack(u: Unpacker, v: MessageContext): Unit = {
val s = u.unpackString
val s = u.unpackValue.toString
TypeConverter.convertToCls(s, codec.rawType) match {
case Some(x) =>
v.setObject(x)
Expand Down
Expand Up @@ -60,10 +60,7 @@ case class GenericException(
}

object GenericException {
def fromThrowable(e: Throwable, seen: Set[Throwable] = Set.empty): GenericException = {
val exceptionClass = e.getClass.getName
val message = Option(e.getMessage).getOrElse(e.getClass.getSimpleName)

def extractStackTrace(e: Throwable): Seq[GenericStackTraceElement] = {
val stackTrace = for (x <- e.getStackTrace) yield {
GenericStackTraceElement(
className = x.getClassName,
Expand All @@ -72,7 +69,14 @@ object GenericException {
lineNumber = x.getLineNumber
)
}
stackTrace
}

def fromThrowable(e: Throwable, seen: Set[Throwable] = Set.empty): GenericException = {
val exceptionClass = e.getClass.getName
val message = Option(e.getMessage).getOrElse(e.getClass.getSimpleName)

val stackTrace = extractStackTrace(e)
val cause = Option(e.getCause).flatMap { ce =>
if (seen.contains(ce)) {
None
Expand Down
Expand Up @@ -15,7 +15,7 @@ package wvlet.airframe.http.grpc.internal

import io.grpc.{Metadata, Status, StatusException, StatusRuntimeException}
import wvlet.airframe.codec.MessageCodecException
import wvlet.airframe.http.{GrpcStatus, HttpServerException, HttpStatus}
import wvlet.airframe.http.{GrpcStatus, HttpServerException, HttpStatus, RPCException}
import wvlet.log.LogSupport

import java.lang.reflect.InvocationTargetException
Expand All @@ -26,7 +26,7 @@ import scala.concurrent.ExecutionException
*/
object GrpcException extends LogSupport {

private[grpc] val rpcErrorKey = Metadata.Key.of[String]("airframe_rpc_error", Metadata.ASCII_STRING_MARSHALLER)
private[grpc] val rpcErrorBodyKey = Metadata.Key.of[String]("airframe_rpc_error", Metadata.ASCII_STRING_MARSHALLER)

/**
* Convert an exception to gRPC-specific exception types
Expand Down Expand Up @@ -81,11 +81,28 @@ object GrpcException extends LogSupport {
if (e.message.nonEmpty) {
val m = e.message
val metadata = new Metadata()
metadata.put[String](rpcErrorKey, s"${m.toContentString}")
metadata.put[String](rpcErrorBodyKey, s"${m.toContentString}")
s.asRuntimeException(metadata)
} else {
s.asRuntimeException()
}
case e: RPCException =>
val grpcStatus = e.status.grpcStatus
val s = Status
.fromCodeValue(grpcStatus.code)
.withCause(e.cause.getOrElse(null))
.withDescription(e.getMessage)

val metadata = new Metadata()
try {
metadata.put[String](rpcErrorBodyKey, e.toJson)
} catch {
case ex: Throwable =>
// Failed to build JSON data.
// Just show warning so as not to block the RPC response
warn(s"Failed to serialize RPCException: ${e}", ex)
}
s.asRuntimeException(metadata)
case other =>
io.grpc.Status.INTERNAL
.withCause(other)
Expand Down
Expand Up @@ -16,11 +16,14 @@ package wvlet.airframe.http.grpc
import io.grpc.Status
import io.grpc.Status.Code
import io.grpc.StatusRuntimeException
import wvlet.airframe.http.Router
import wvlet.airframe.http.{RPCException, RPCStatus, Router}
import wvlet.airframe.http.grpc.GrpcErrorLogTest.DemoApiDebug
import wvlet.airframe.http.grpc.example.DemoApi.DemoApiClient
import wvlet.airframe.http.grpc.internal.GrpcException
import wvlet.airspec.AirSpec
import wvlet.log.{LogLevel, Logger}

import java.io.{PrintWriter, StringWriter}

object GrpcErrorHandlingTest extends AirSpec {

Expand All @@ -31,16 +34,86 @@ object GrpcErrorHandlingTest extends AirSpec {
.designWithChannel
}

test("handle error") { (client: DemoApiClient) =>
private def suppressLog(loggerName: String)(body: => Unit): Unit = {
val l = Logger(loggerName)
val previousLogLevel = l.getLogLevel
try {
// Suppress error logs
l.setLogLevel(LogLevel.OFF)
body
} finally {
l.setLogLevel(previousLogLevel)
}
}
test("exception test") { (client: DemoApiClient) =>
warn("Starting a gRPC error handling test")
val ex = intercept[StatusRuntimeException] {
client.error409Test
suppressLog("wvlet.airframe.http.grpc.internal") {

test("propagate HttpServerException") {
val ex = intercept[StatusRuntimeException] {
client.error409Test
}
ex.getMessage.contains("409") shouldBe true
ex.getStatus.isOk shouldBe false
ex.getStatus.getCode shouldBe Code.ABORTED
val trailers = Status.trailersFromThrowable(ex)
val rpcError = trailers.get[String](GrpcException.rpcErrorBodyKey)
rpcError.contains("test message") shouldBe true
}

test("propagate RPCException") {
val ex = intercept[StatusRuntimeException] {
client.rpcExceptionTest(false)
}
val trailers = Status.trailersFromThrowable(ex)
val rpcErrorJson = trailers.get[String](GrpcException.rpcErrorBodyKey)
val e = RPCException.fromJson(rpcErrorJson)

e.status shouldBe RPCStatus.SYNTAX_ERROR_U3
e.message shouldBe "test RPC exception"
e.cause shouldNotBe empty
e.appErrorCode shouldBe Some(11)
e.metadata shouldBe Map("retry" -> 0)

// Extract stack trace
val s = new StringWriter()
val out = new PrintWriter(s)
e.printStackTrace(out)
out.flush()

val stackTrace = s.toString
// Stack trace should contain two traces from the exception itself and its cause
stackTrace.contains("wvlet.airframe.http.RPCStatus.newException") shouldBe true
stackTrace.contains("wvlet.airframe.http.grpc.example.DemoApi.throwEx") shouldBe true
stackTrace.contains("wvlet.airframe.http.grpc.example.DemoApi.rpcExceptionTest") shouldBe true
}

test("suppress RPCException stacktrace") {
val ex = intercept[StatusRuntimeException] {
client.rpcExceptionTest(true)
}
val trailers = Status.trailersFromThrowable(ex)
val rpcErrorJson = trailers.get[String](GrpcException.rpcErrorBodyKey)
val e = RPCException.fromJson(rpcErrorJson)

e.status shouldBe RPCStatus.SYNTAX_ERROR_U3
e.message shouldBe "test RPC exception"
e.cause shouldBe empty
e.appErrorCode shouldBe Some(11)
e.metadata shouldBe Map("retry" -> 0)

// Extract stack trace
val s = new StringWriter()
val out = new PrintWriter(s)
e.printStackTrace(out)
out.flush()

val stackTrace = s.toString
// Stack trace should not have detailed information when RPCException.noStackTrace is used
stackTrace.contains("wvlet.airframe.http.RPCStatus.newException") shouldBe false
stackTrace.contains("wvlet.airframe.http.grpc.example.DemoApi.throwEx") shouldBe false
stackTrace.contains("wvlet.airframe.http.grpc.example.DemoApi.rpcExceptionTest") shouldBe false
}
}
ex.getMessage.contains("409") shouldBe true
ex.getStatus.isOk shouldBe false
ex.getStatus.getCode shouldBe Code.ABORTED
val trailers = Status.trailersFromThrowable(ex)
val rpcError = trailers.get[String](GrpcException.rpcErrorKey)
info(s"error trailer: ${rpcError}")
}
}
Expand Up @@ -20,7 +20,7 @@ import wvlet.airframe.codec.MessageCodecFactory
import wvlet.airframe.http.grpc.internal.GrpcServiceBuilder
import wvlet.airframe.http.grpc._
import wvlet.airframe.http.router.Route
import wvlet.airframe.http.{Http, HttpStatus, RPC, Router}
import wvlet.airframe.http.{Http, HttpStatus, RPC, RPCStatus, Router}
import wvlet.airframe.msgpack.spi.MsgPack
import wvlet.airframe.rx.{Rx, RxStream}
import wvlet.log.LogSupport
Expand Down Expand Up @@ -65,6 +65,27 @@ trait DemoApi extends LogSupport {
def error409Test: String = {
throw Http.serverException(HttpStatus.Conflict_409).withContent("test message")
}

private def throwEx = throw new IllegalArgumentException("syntax error")

def rpcExceptionTest(suppress: Boolean): String = {
try {
throwEx
""
} catch {
case e: Throwable =>
val ex = RPCStatus.SYNTAX_ERROR_U3.newException(
message = "test RPC exception",
cause = e,
appErrorCode = 11,
metadata = Map("retry" -> 0)
)
if (suppress) {
ex.noStackTrace
}
throw ex
}
}
}

object DemoApi {
Expand Down Expand Up @@ -118,6 +139,8 @@ object DemoApi {
GrpcServiceBuilder.buildMethodDescriptor(getRoute("returnUnit"), codecFactory)
private val errorTestMethodDescriptor =
GrpcServiceBuilder.buildMethodDescriptor(getRoute("error409Test"), codecFactory)
private val rpcExceptionTestMethodDescriptor =
GrpcServiceBuilder.buildMethodDescriptor(getRoute("rpcExceptionTest"), codecFactory)

def withEncoding(encoding: GrpcEncoding): DemoApiClient = {
this.copy(encoding = encoding)
Expand Down Expand Up @@ -209,6 +232,18 @@ object DemoApi {

resp.asInstanceOf[String]
}

def rpcExceptionTest(suppress: Boolean): String = {
val resp = ClientCalls
.blockingUnaryCall(
_channel,
rpcExceptionTestMethodDescriptor,
getCallOptions,
encode(Map("suppress" -> suppress))
)

resp.asInstanceOf[String]
}
}

}
89 changes: 74 additions & 15 deletions airframe-http/src/main/scala/wvlet/airframe/http/RPCException.scala
Expand Up @@ -13,27 +13,86 @@
*/
package wvlet.airframe.http

import wvlet.airframe.codec.{GenericException, GenericStackTraceElement, MessageCodec}

/**
* RPCException provides a backend-independent (e.g., Finagle or gRPC) RPC error reporting mechanism.
* RPCException provides a backend-independent (e.g., Finagle or gRPC) RPC error reporting mechanism. Create this
* exception with (RPCStatus code).toException(...) method.
*
* @param rpcError
* If necessary, we can add more standard error_details parameter like
* https://github.com/googleapis/googleapis/blob/master/google/rpc/error_details.proto
*/
class RPCException(
rpcError: RPCError
) extends Exception(rpcError.toString, rpcError.cause.getOrElse(null))

case class RPCError(
case class RPCException(
// RPC status
status: RPCStatus,
status: RPCStatus = RPCStatus.INTERNAL_ERROR_I0,
// Error message
message: String,
message: String = "",
// Cause of the exception
cause: Option[Throwable] = None,
// Custom data
// [optional] Application-specific status code
appErrorCode: Option[Int] = None,
// [optional] Application-specific metadata
metadata: Map[String, Any] = Map.empty
) extends Exception(s"[${status}] ${message}", cause.getOrElse(null)) {

private var _includeStackTrace: Boolean = true

/**
* Do not embed stacktrace and the cause objects in the RPC exception error response
*/
def noStackTrace: RPCException = {
_includeStackTrace = false
this
}

def toMessage: RPCErrorMessage = {
RPCErrorMessage(
code = status.code,
codeName = status.name,
message = message,
stackTrace = if (_includeStackTrace) Some(GenericException.extractStackTrace(this)) else None,
cause = if (_includeStackTrace) cause else None,
appErrorCode = appErrorCode,
metadata = metadata
)
}

def toJson: String = {
MessageCodec.of[RPCErrorMessage].toJson(toMessage)
}
}

/**
* A model class for RPC error message body. This message will be embedded to HTTP response body or gRPC trailer.
*
* We need this class to avoid directly serde RPCException classes with airframe-codec, so that we can properly
* propagate the exact stack trace to the client.
*/
case class RPCErrorMessage(
code: Int = RPCStatus.UNKNOWN_I1.code,
codeName: String = RPCStatus.UNKNOWN_I1.name,
message: String = "",
stackTrace: Option[Seq[GenericStackTraceElement]] = None,
cause: Option[Throwable] = None,
appErrorCode: Option[Int] = None,
metadata: Map[String, Any] = Map.empty
) {
override def toString: String = s"[${status}] ${message}"
def toException: RPCException = new RPCException(this)
def withMessage(newMessage: String): RPCError = this.copy(message = newMessage)
def withMetadata(newMetadata: Map[String, Any]): RPCError = this.copy(metadata = newMetadata)
)

object RPCException {
def fromJson(json: String): RPCException = {
val codec = MessageCodec.of[RPCErrorMessage]
val m = codec.fromJson(json)
val ex = new RPCException(
status = RPCStatus.ofCode(m.code),
message = m.message,
cause = m.cause,
appErrorCode = m.appErrorCode,
metadata = m.metadata
)
// Recover the original stack trace
m.stackTrace.foreach { x =>
ex.setStackTrace(x.map(_.toJavaStackTraceElement).toArray)
}
ex
}
}

0 comments on commit 7f21193

Please sign in to comment.