diff --git a/contrib/grpc-contrib/src/main/java/com/salesforce/grpc/contrib/interceptor/TransmitUnexpectedExceptionInterceptor.java b/contrib/grpc-contrib/src/main/java/com/salesforce/grpc/contrib/interceptor/TransmitUnexpectedExceptionInterceptor.java new file mode 100644 index 00000000..26cc9575 --- /dev/null +++ b/contrib/grpc-contrib/src/main/java/com/salesforce/grpc/contrib/interceptor/TransmitUnexpectedExceptionInterceptor.java @@ -0,0 +1,151 @@ +/* + * Copyright (c) 2019, Salesforce.com, Inc. + * All rights reserved. + * Licensed under the BSD 3-Clause license. + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + */ + +package com.salesforce.grpc.contrib.interceptor; + +import io.grpc.*; + +import java.io.PrintWriter; +import java.io.StringWriter; +import java.util.Collection; +import java.util.HashSet; +import java.util.Set; + +/** + * A class that intercepts uncaught exceptions of all types and handles them by closing the {@link ServerCall}, and + * transmitting the exception's description and stack trace to the client. This class is a complement to gRPC's + * {@code TransmitStatusRuntimeExceptionInterceptor}. + * + *

Without this interceptor, gRPC will strip all details and close the {@link ServerCall} with + * a generic {@link Status#UNKNOWN} code. + * + *

Security warning: the exception description and stack trace may contain sensitive server-side + * state information, and generally should not be sent to clients. Only install this interceptor + * if all clients are trusted. + */ +// Heavily inspired by https://github.com/saturnism/grpc-java-by-example/blob/master/error-handling-example/error-server/src/main/java/com/example/grpc/server/UnknownStatusDescriptionInterceptor.java +public class TransmitUnexpectedExceptionInterceptor implements ServerInterceptor { + + private final Set> exactTypes = new HashSet<>(); + private final Set> parentTypes = new HashSet<>(); + + /** + * Allows this interceptor to match on an exact exception type. + * @param exactType The exact type to match on. + * @return this + */ + public TransmitUnexpectedExceptionInterceptor forExactType(Class exactType) { + this.exactTypes.add(exactType); + return this; + } + + /** + * Allows this interceptor to match on a set of exact exception type. + * @param exactTypes The set of exact types to match on. + * @return this + */ + public TransmitUnexpectedExceptionInterceptor forExactTypes(Collection> exactTypes) { + this.exactTypes.addAll(exactTypes); + return this; + } + + /** + * Allows this interceptor to match on any exception type deriving from {@code parentType}. + * @param parentType The parent type to match on. + * @return this + */ + public TransmitUnexpectedExceptionInterceptor forParentType(Class parentType) { + this.parentTypes.add(parentType); + return this; + } + + /** + * Allows this interceptor to match on any exception type deriving from any element of {@code parentTypes}. + * @param parentTypes The set of parent types to match on. + * @return this + */ + public TransmitUnexpectedExceptionInterceptor forParentTypes(Collection> parentTypes) { + this.parentTypes.addAll(parentTypes); + return this; + } + + /** + * Allows this interceptor to match all exceptions. Use with caution! + * @return this + */ + public TransmitUnexpectedExceptionInterceptor forAllExceptions() { + return forParentType(Throwable.class); + } + + @Override + public ServerCall.Listener interceptCall(ServerCall call, Metadata headers, ServerCallHandler next) { + ServerCall wrappedCall = new ForwardingServerCall.SimpleForwardingServerCall(call) { + @Override + public void sendMessage(RespT message) { + super.sendMessage(message); + } + + @Override + public void close(Status status, Metadata trailers) { + + if (status.getCode() == Status.Code.UNKNOWN && + status.getDescription() == null && + status.getCause() != null && + exceptionTypeIsAllowed(status.getCause().getClass())) { + Throwable e = status.getCause(); + status = Status.INTERNAL + .withDescription(e.getMessage()) + .augmentDescription(stacktraceToString(e)); + } + super.close(status, trailers); + } + }; + + return new ForwardingServerCallListener.SimpleForwardingServerCallListener(next.startCall(wrappedCall, headers)) { + @Override + public void onHalfClose() { + try { + super.onHalfClose(); + } catch (Throwable e) { + if (exceptionTypeIsAllowed(e.getClass())) { + call.close(Status.INTERNAL + .withDescription(e.getMessage()) + .augmentDescription(stacktraceToString(e)), new Metadata()); + } else { + throw e; + } + } + } + }; + } + + private boolean exceptionTypeIsAllowed(Class exceptionClass) { + // exact matches + for (Class clazz : exactTypes) { + if (clazz.equals(exceptionClass)) { + return true; + } + } + + // parent type matches + for (Class clazz : parentTypes) { + if (clazz.isAssignableFrom(exceptionClass)) { + return true; + } + } + + // no match + return false; + } + + private String stacktraceToString(Throwable e) { + StringWriter stringWriter = new StringWriter(); + PrintWriter printWriter = new PrintWriter(stringWriter); + e.printStackTrace(printWriter); + return stringWriter.toString(); + } +} \ No newline at end of file diff --git a/contrib/grpc-contrib/src/test/java/com/salesforce/grpc/contrib/interceptor/TransmitUnexpectedExceptionInterceptorTest.java b/contrib/grpc-contrib/src/test/java/com/salesforce/grpc/contrib/interceptor/TransmitUnexpectedExceptionInterceptorTest.java new file mode 100644 index 00000000..9cdcde4c --- /dev/null +++ b/contrib/grpc-contrib/src/test/java/com/salesforce/grpc/contrib/interceptor/TransmitUnexpectedExceptionInterceptorTest.java @@ -0,0 +1,212 @@ +/* + * Copyright (c) 2019, Salesforce.com, Inc. + * All rights reserved. + * Licensed under the BSD 3-Clause license. + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + */ + +package com.salesforce.grpc.contrib.interceptor; + +import com.salesforce.grpc.contrib.GreeterGrpc; +import com.salesforce.grpc.contrib.HelloRequest; +import com.salesforce.grpc.contrib.HelloResponse; +import io.grpc.ServerInterceptor; +import io.grpc.ServerInterceptors; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.GrpcServerRule; +import org.junit.Rule; +import org.junit.Test; + +import java.util.Iterator; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TransmitUnexpectedExceptionInterceptorTest { + @Rule + public final GrpcServerRule serverRule = new GrpcServerRule(); + + @Test + public void noExceptionDoesNotInterfere() { + GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() { + @Override + public void sayHello(HelloRequest request, StreamObserver responseObserver) { + responseObserver.onNext(HelloResponse.newBuilder().setMessage("Hello " + request.getName()).build()); + responseObserver.onCompleted(); + } + }; + + ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor(); + + serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor)); + GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel()); + + stub.sayHello(HelloRequest.newBuilder().setName("World").build()); + } + + @Test + public void exactTypeMatches() { + GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() { + @Override + public void sayHello(HelloRequest request, StreamObserver responseObserver) { + responseObserver.onError(new ArithmeticException("Divide by zero")); + } + }; + + ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor().forExactType(ArithmeticException.class); + + serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor)); + GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel()); + + assertThatThrownBy(() -> stub.sayHello(HelloRequest.newBuilder().setName("World").build())) + .isInstanceOf(StatusRuntimeException.class) + .matches(sre -> ((StatusRuntimeException) sre).getStatus().getCode().equals(Status.INTERNAL.getCode()), "is Status.INTERNAL") + .hasMessageContaining("Divide by zero"); + } + + @Test + public void parentTypeMatches() { + GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() { + @Override + public void sayHello(HelloRequest request, StreamObserver responseObserver) { + responseObserver.onError(new ArithmeticException("Divide by zero")); + } + }; + + ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor().forParentType(RuntimeException.class); + + serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor)); + GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel()); + + assertThatThrownBy(() -> stub.sayHello(HelloRequest.newBuilder().setName("World").build())) + .isInstanceOf(StatusRuntimeException.class) + .matches(sre -> ((StatusRuntimeException) sre).getStatus().getCode().equals(Status.INTERNAL.getCode()), "is Status.INTERNAL") + .hasMessageContaining("Divide by zero"); + } + + @Test + public void parentTypeMatchesExactly() { + GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() { + @Override + public void sayHello(HelloRequest request, StreamObserver responseObserver) { + responseObserver.onError(new RuntimeException("Divide by zero")); + } + }; + + ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor().forParentType(RuntimeException.class); + + serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor)); + GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel()); + + assertThatThrownBy(() -> stub.sayHello(HelloRequest.newBuilder().setName("World").build())) + .isInstanceOf(StatusRuntimeException.class) + .matches(sre -> ((StatusRuntimeException) sre).getStatus().getCode().equals(Status.INTERNAL.getCode()), "is Status.INTERNAL") + .hasMessageContaining("Divide by zero"); + } + + @Test + public void alleMatches() { + GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() { + @Override + public void sayHello(HelloRequest request, StreamObserver responseObserver) { + responseObserver.onError(new ArithmeticException("Divide by zero")); + } + }; + + ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor().forAllExceptions(); + + serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor)); + GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel()); + + assertThatThrownBy(() -> stub.sayHello(HelloRequest.newBuilder().setName("World").build())) + .isInstanceOf(StatusRuntimeException.class) + .matches(sre -> ((StatusRuntimeException) sre).getStatus().getCode().equals(Status.INTERNAL.getCode()), "is Status.INTERNAL") + .hasMessageContaining("Divide by zero"); + } + + @Test + public void unknownTypeDoesNotMatch() { + GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() { + @Override + public void sayHello(HelloRequest request, StreamObserver responseObserver) { + responseObserver.onError(new NullPointerException("Bananas!")); + } + }; + + ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor().forExactType(ArithmeticException.class); + + serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor)); + GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel()); + + assertThatThrownBy(() -> stub.sayHello(HelloRequest.newBuilder().setName("World").build())) + .isInstanceOf(StatusRuntimeException.class) + .matches(sre -> ((StatusRuntimeException) sre).getStatus().getCode().equals(Status.UNKNOWN.getCode()), "is Status.UNKNOWN") + .hasMessageContaining("UNKNOWN"); + } + + @Test + public void unexpectedExceptionCanMatch() { + GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() { + @Override + public void sayHello(HelloRequest request, StreamObserver responseObserver) { + throw new ArithmeticException("Divide by zero"); + } + }; + + ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor().forExactType(ArithmeticException.class); + + serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor)); + GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel()); + + assertThatThrownBy(() -> stub.sayHello(HelloRequest.newBuilder().setName("World").build())) + .isInstanceOf(StatusRuntimeException.class) + .matches(sre -> ((StatusRuntimeException) sre).getStatus().getCode().equals(Status.INTERNAL.getCode()), "is Status.INTERNAL") + .hasMessageContaining("Divide by zero"); + } + + @Test + public void unexpectedExceptionCanNotMatch() { + GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() { + @Override + public void sayHello(HelloRequest request, StreamObserver responseObserver) { + throw new ArithmeticException("Divide by zero"); + } + }; + + ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor().forExactType(NullPointerException.class); + + serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor)); + GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel()); + + assertThatThrownBy(() -> stub.sayHello(HelloRequest.newBuilder().setName("World").build())) + .isInstanceOf(StatusRuntimeException.class) + .matches(sre -> ((StatusRuntimeException) sre).getStatus().getCode().equals(Status.UNKNOWN.getCode()), "is Status.UNKNOWN") + .hasMessageContaining("UNKNOWN"); + } + + @Test + public void unexpectedExceptionCanMatchStreaming() { + GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() { + @Override + public void sayHelloStream(HelloRequest request, StreamObserver responseObserver) { + responseObserver.onNext(HelloResponse.getDefaultInstance()); + responseObserver.onNext(HelloResponse.getDefaultInstance()); + throw new ArithmeticException("Divide by zero"); + } + }; + + ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor().forExactType(ArithmeticException.class); + + serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor)); + GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel()); + + Iterator it = stub.sayHelloStream(HelloRequest.newBuilder().setName("World").build()); + it.next(); + it.next(); + assertThatThrownBy(it::next) + .isInstanceOf(StatusRuntimeException.class) + .matches(sre -> ((StatusRuntimeException) sre).getStatus().getCode().equals(Status.INTERNAL.getCode()), "is Status.INTERNAL") + .hasMessageContaining("Divide by zero"); + } +} diff --git a/contrib/grpc-contrib/src/test/proto/helloworld.proto b/contrib/grpc-contrib/src/test/proto/helloworld.proto index d1abddce..a68cb492 100644 --- a/contrib/grpc-contrib/src/test/proto/helloworld.proto +++ b/contrib/grpc-contrib/src/test/proto/helloworld.proto @@ -14,6 +14,9 @@ service Greeter { // Sends a greeting rpc SayHello (HelloRequest) returns (HelloResponse) {} + // Sends many greetings + rpc SayHelloStream (HelloRequest) returns (stream HelloResponse) {} + // Sends the current time rpc SayTime (google.protobuf.Empty) returns (TimeResponse) {} }