Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
/*
* Copyright (c) 2019, Salesforce.com, Inc.
* All rights reserved.
* Licensed under the BSD 3-Clause license.
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
*/

package com.salesforce.grpc.contrib.interceptor;

import io.grpc.*;

import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;

/**
* A class that intercepts uncaught exceptions of all types and handles them by closing the {@link ServerCall}, and
* transmitting the exception's description and stack trace to the client. This class is a complement to gRPC's
* {@code TransmitStatusRuntimeExceptionInterceptor}.
*
* <p>Without this interceptor, gRPC will strip all details and close the {@link ServerCall} with
* a generic {@link Status#UNKNOWN} code.
*
* <p>Security warning: the exception description and stack trace may contain sensitive server-side
* state information, and generally should not be sent to clients. Only install this interceptor
* if all clients are trusted.
*/
// Heavily inspired by https://github.com/saturnism/grpc-java-by-example/blob/master/error-handling-example/error-server/src/main/java/com/example/grpc/server/UnknownStatusDescriptionInterceptor.java
public class TransmitUnexpectedExceptionInterceptor implements ServerInterceptor {

private final Set<Class<? extends Throwable>> exactTypes = new HashSet<>();
private final Set<Class<? extends Throwable>> parentTypes = new HashSet<>();

/**
* Allows this interceptor to match on an exact exception type.
* @param exactType The exact type to match on.
* @return this
*/
public TransmitUnexpectedExceptionInterceptor forExactType(Class<? extends Throwable> exactType) {
this.exactTypes.add(exactType);
return this;
}

/**
* Allows this interceptor to match on a set of exact exception type.
* @param exactTypes The set of exact types to match on.
* @return this
*/
public TransmitUnexpectedExceptionInterceptor forExactTypes(Collection<Class<? extends Throwable>> exactTypes) {
this.exactTypes.addAll(exactTypes);
return this;
}

/**
* Allows this interceptor to match on any exception type deriving from {@code parentType}.
* @param parentType The parent type to match on.
* @return this
*/
public TransmitUnexpectedExceptionInterceptor forParentType(Class<? extends Throwable> parentType) {
this.parentTypes.add(parentType);
return this;
}

/**
* Allows this interceptor to match on any exception type deriving from any element of {@code parentTypes}.
* @param parentTypes The set of parent types to match on.
* @return this
*/
public TransmitUnexpectedExceptionInterceptor forParentTypes(Collection<Class<? extends Throwable>> parentTypes) {
this.parentTypes.addAll(parentTypes);
return this;
}

/**
* Allows this interceptor to match all exceptions. Use with caution!
* @return this
*/
public TransmitUnexpectedExceptionInterceptor forAllExceptions() {
return forParentType(Throwable.class);
}

@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
ServerCall<ReqT, RespT> wrappedCall = new ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT>(call) {
@Override
public void sendMessage(RespT message) {
super.sendMessage(message);
}

@Override
public void close(Status status, Metadata trailers) {

if (status.getCode() == Status.Code.UNKNOWN &&
status.getDescription() == null &&
status.getCause() != null &&
exceptionTypeIsAllowed(status.getCause().getClass())) {
Throwable e = status.getCause();
status = Status.INTERNAL
.withDescription(e.getMessage())
.augmentDescription(stacktraceToString(e));
}
super.close(status, trailers);
}
};

return new ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT>(next.startCall(wrappedCall, headers)) {
@Override
public void onHalfClose() {
try {
super.onHalfClose();
} catch (Throwable e) {
if (exceptionTypeIsAllowed(e.getClass())) {
call.close(Status.INTERNAL
.withDescription(e.getMessage())
.augmentDescription(stacktraceToString(e)), new Metadata());
} else {
throw e;
}
}
}
};
}

private boolean exceptionTypeIsAllowed(Class<? extends Throwable> exceptionClass) {
// exact matches
for (Class<? extends Throwable> clazz : exactTypes) {
if (clazz.equals(exceptionClass)) {
return true;
}
}

// parent type matches
for (Class<? extends Throwable> clazz : parentTypes) {
if (clazz.isAssignableFrom(exceptionClass)) {
return true;
}
}

// no match
return false;
}

private String stacktraceToString(Throwable e) {
StringWriter stringWriter = new StringWriter();
PrintWriter printWriter = new PrintWriter(stringWriter);
e.printStackTrace(printWriter);
return stringWriter.toString();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
/*
* Copyright (c) 2019, Salesforce.com, Inc.
* All rights reserved.
* Licensed under the BSD 3-Clause license.
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
*/

package com.salesforce.grpc.contrib.interceptor;

import com.salesforce.grpc.contrib.GreeterGrpc;
import com.salesforce.grpc.contrib.HelloRequest;
import com.salesforce.grpc.contrib.HelloResponse;
import io.grpc.ServerInterceptor;
import io.grpc.ServerInterceptors;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.GrpcServerRule;
import org.junit.Rule;
import org.junit.Test;

import java.util.Iterator;

import static org.assertj.core.api.Assertions.assertThatThrownBy;

public class TransmitUnexpectedExceptionInterceptorTest {
@Rule
public final GrpcServerRule serverRule = new GrpcServerRule();

@Test
public void noExceptionDoesNotInterfere() {
GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() {
@Override
public void sayHello(HelloRequest request, StreamObserver<HelloResponse> responseObserver) {
responseObserver.onNext(HelloResponse.newBuilder().setMessage("Hello " + request.getName()).build());
responseObserver.onCompleted();
}
};

ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor();

serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor));
GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel());

stub.sayHello(HelloRequest.newBuilder().setName("World").build());
}

@Test
public void exactTypeMatches() {
GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() {
@Override
public void sayHello(HelloRequest request, StreamObserver<HelloResponse> responseObserver) {
responseObserver.onError(new ArithmeticException("Divide by zero"));
}
};

ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor().forExactType(ArithmeticException.class);

serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor));
GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel());

assertThatThrownBy(() -> stub.sayHello(HelloRequest.newBuilder().setName("World").build()))
.isInstanceOf(StatusRuntimeException.class)
.matches(sre -> ((StatusRuntimeException) sre).getStatus().getCode().equals(Status.INTERNAL.getCode()), "is Status.INTERNAL")
.hasMessageContaining("Divide by zero");
}

@Test
public void parentTypeMatches() {
GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() {
@Override
public void sayHello(HelloRequest request, StreamObserver<HelloResponse> responseObserver) {
responseObserver.onError(new ArithmeticException("Divide by zero"));
}
};

ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor().forParentType(RuntimeException.class);

serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor));
GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel());

assertThatThrownBy(() -> stub.sayHello(HelloRequest.newBuilder().setName("World").build()))
.isInstanceOf(StatusRuntimeException.class)
.matches(sre -> ((StatusRuntimeException) sre).getStatus().getCode().equals(Status.INTERNAL.getCode()), "is Status.INTERNAL")
.hasMessageContaining("Divide by zero");
}

@Test
public void parentTypeMatchesExactly() {
GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() {
@Override
public void sayHello(HelloRequest request, StreamObserver<HelloResponse> responseObserver) {
responseObserver.onError(new RuntimeException("Divide by zero"));
}
};

ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor().forParentType(RuntimeException.class);

serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor));
GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel());

assertThatThrownBy(() -> stub.sayHello(HelloRequest.newBuilder().setName("World").build()))
.isInstanceOf(StatusRuntimeException.class)
.matches(sre -> ((StatusRuntimeException) sre).getStatus().getCode().equals(Status.INTERNAL.getCode()), "is Status.INTERNAL")
.hasMessageContaining("Divide by zero");
}

@Test
public void alleMatches() {
GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() {
@Override
public void sayHello(HelloRequest request, StreamObserver<HelloResponse> responseObserver) {
responseObserver.onError(new ArithmeticException("Divide by zero"));
}
};

ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor().forAllExceptions();

serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor));
GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel());

assertThatThrownBy(() -> stub.sayHello(HelloRequest.newBuilder().setName("World").build()))
.isInstanceOf(StatusRuntimeException.class)
.matches(sre -> ((StatusRuntimeException) sre).getStatus().getCode().equals(Status.INTERNAL.getCode()), "is Status.INTERNAL")
.hasMessageContaining("Divide by zero");
}

@Test
public void unknownTypeDoesNotMatch() {
GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() {
@Override
public void sayHello(HelloRequest request, StreamObserver<HelloResponse> responseObserver) {
responseObserver.onError(new NullPointerException("Bananas!"));
}
};

ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor().forExactType(ArithmeticException.class);

serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor));
GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel());

assertThatThrownBy(() -> stub.sayHello(HelloRequest.newBuilder().setName("World").build()))
.isInstanceOf(StatusRuntimeException.class)
.matches(sre -> ((StatusRuntimeException) sre).getStatus().getCode().equals(Status.UNKNOWN.getCode()), "is Status.UNKNOWN")
.hasMessageContaining("UNKNOWN");
}

@Test
public void unexpectedExceptionCanMatch() {
GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() {
@Override
public void sayHello(HelloRequest request, StreamObserver<HelloResponse> responseObserver) {
throw new ArithmeticException("Divide by zero");
}
};

ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor().forExactType(ArithmeticException.class);

serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor));
GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel());

assertThatThrownBy(() -> stub.sayHello(HelloRequest.newBuilder().setName("World").build()))
.isInstanceOf(StatusRuntimeException.class)
.matches(sre -> ((StatusRuntimeException) sre).getStatus().getCode().equals(Status.INTERNAL.getCode()), "is Status.INTERNAL")
.hasMessageContaining("Divide by zero");
}

@Test
public void unexpectedExceptionCanNotMatch() {
GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() {
@Override
public void sayHello(HelloRequest request, StreamObserver<HelloResponse> responseObserver) {
throw new ArithmeticException("Divide by zero");
}
};

ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor().forExactType(NullPointerException.class);

serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor));
GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel());

assertThatThrownBy(() -> stub.sayHello(HelloRequest.newBuilder().setName("World").build()))
.isInstanceOf(StatusRuntimeException.class)
.matches(sre -> ((StatusRuntimeException) sre).getStatus().getCode().equals(Status.UNKNOWN.getCode()), "is Status.UNKNOWN")
.hasMessageContaining("UNKNOWN");
}

@Test
public void unexpectedExceptionCanMatchStreaming() {
GreeterGrpc.GreeterImplBase svc = new GreeterGrpc.GreeterImplBase() {
@Override
public void sayHelloStream(HelloRequest request, StreamObserver<HelloResponse> responseObserver) {
responseObserver.onNext(HelloResponse.getDefaultInstance());
responseObserver.onNext(HelloResponse.getDefaultInstance());
throw new ArithmeticException("Divide by zero");
}
};

ServerInterceptor interceptor = new TransmitUnexpectedExceptionInterceptor().forExactType(ArithmeticException.class);

serverRule.getServiceRegistry().addService(ServerInterceptors.intercept(svc, interceptor));
GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(serverRule.getChannel());

Iterator<HelloResponse> it = stub.sayHelloStream(HelloRequest.newBuilder().setName("World").build());
it.next();
it.next();
assertThatThrownBy(it::next)
.isInstanceOf(StatusRuntimeException.class)
.matches(sre -> ((StatusRuntimeException) sre).getStatus().getCode().equals(Status.INTERNAL.getCode()), "is Status.INTERNAL")
.hasMessageContaining("Divide by zero");
}
}
3 changes: 3 additions & 0 deletions contrib/grpc-contrib/src/test/proto/helloworld.proto
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ service Greeter {
// Sends a greeting
rpc SayHello (HelloRequest) returns (HelloResponse) {}

// Sends many greetings
rpc SayHelloStream (HelloRequest) returns (stream HelloResponse) {}

// Sends the current time
rpc SayTime (google.protobuf.Empty) returns (TimeResponse) {}
}
Expand Down