Skip to content
Permalink
Browse files

adds support for propagating grpc server-side cancellations (#844)

* adds support for propagating grpc server-side cancellations

* adds messages to test assertions

* closes trailers eagerly to allow pure header responses to be published asap

* removes unnecessary when clauses

* avoids overloading the id field in CourierRequest
introduces separate variant field to determine cancellation operation
adds unary server exit case
cleans up return values from the grpc client
adds assertion on grpc health check response body

* documents reasons for closing the channels

* adds tests for client-side streaming and server-side cancellations with unary responses

* initializes grpc client result Status to UNKNOWN instead of OK
asserts on individual fields on the ping response

* improves assertion messages
avoids warnings due to presence of overloaded methods in grpc client

* improves assertion messages
avoids warnings due to presence of overloaded methods in grpc client
adds correlation-id to logs from grpc client

* avoids race in reading status on error path in grpc-client
  • Loading branch information...
shamsimam authored and sradack committed Jul 11, 2019
1 parent c6b2af1 commit dd6bd41ecc1a7832e0051b3a7c6bb76967c063bc
@@ -6,7 +6,7 @@

<groupId>twosigma</groupId>
<artifactId>courier</artifactId>
<version>1.2.1</version>
<version>1.4.6</version>

<name>courier</name>
<url>https://github.com/twosigma/waiter/tree/master/test-apps/courier</url>

Large diffs are not rendered by default.

@@ -22,6 +22,9 @@
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.StreamObserver;

import java.io.IOException;
@@ -34,6 +37,14 @@

private final static Logger LOGGER = Logger.getLogger(GrpcServer.class.getName());

private static void sleep(final int durationMillis) {
try {
Thread.sleep(durationMillis);
} catch (final Exception ex) {
ex.printStackTrace();
}
}

private Server server;

void start(final int port) throws IOException {
@@ -76,44 +87,89 @@ public void sendPackage(final CourierRequest request, final StreamObserver<Couri
"id=" + request.getId() + ", " +
"from=" + request.getFrom() + ", " +
"message.length=" + request.getMessage().length() + "}");
final CourierReply reply = CourierReply
.newBuilder()
.setId(request.getId())
.setMessage(request.getMessage())
.setResponse("received")
.build();
LOGGER.info("Sending CourierReply for id=" + reply.getId());
responseObserver.onNext(reply);
responseObserver.onCompleted();
if (responseObserver instanceof ServerCallStreamObserver) {
((ServerCallStreamObserver) responseObserver).setOnCancelHandler(() -> {
LOGGER.info("CancelHandler:sendPackage CourierRequest{" + "id=" + request.getId() + "} was cancelled");
});
}
if (Variant.SEND_ERROR.equals(request.getVariant())) {
final StatusRuntimeException error = Status.CANCELLED
.withCause(new RuntimeException(request.getId()))
.withDescription("Cancelled by server")
.asRuntimeException();
LOGGER.info("Sending cancelled by server error");
responseObserver.onError(error);
} else if (Variant.EXIT_PRE_RESPONSE.equals(request.getVariant())) {
sleep(1000);
LOGGER.info("Exiting server abruptly");
System.exit(1);
} else {
final CourierReply reply = CourierReply
.newBuilder()
.setId(request.getId())
.setMessage(request.getMessage())
.setResponse("received")
.build();
LOGGER.info("Sending CourierReply for id=" + reply.getId());
responseObserver.onNext(reply);
responseObserver.onCompleted();
}
}

@Override
public StreamObserver<CourierRequest> collectPackages(final StreamObserver<CourierSummary> responseObserver) {

if (responseObserver instanceof ServerCallStreamObserver) {
((ServerCallStreamObserver) responseObserver).setOnCancelHandler(() -> {
LOGGER.info("CancelHandler:collectPackages() was cancelled");
});
}
return new StreamObserver<CourierRequest>() {

private long numMessages = 0;
private long totalLength = 0;

@Override
public void onNext(final CourierRequest courierRequest) {
LOGGER.info("Received CourierRequest id=" + courierRequest.getId());
public void onNext(final CourierRequest request) {
LOGGER.info("Received CourierRequest id=" + request.getId());

numMessages += 1;
totalLength += courierRequest.getMessage().length();
totalLength += request.getMessage().length();
LOGGER.info("Summary of collected packages: numMessages=" + numMessages +
" with totalLength=" + totalLength);

final CourierSummary courierSummary = CourierSummary
.newBuilder()
.setNumMessages(numMessages)
.setTotalLength(totalLength)
.build();
LOGGER.info("Sending CourierSummary for id=" + courierRequest.getId());
responseObserver.onNext(courierSummary);
if (Variant.EXIT_PRE_RESPONSE.equals(request.getVariant())) {
sleep(1000);
LOGGER.info("Exiting server abruptly");
System.exit(1);
} else if (Variant.SEND_ERROR.equals(request.getVariant())) {
final StatusRuntimeException error = Status.CANCELLED
.withCause(new RuntimeException(request.getId()))
.withDescription("Cancelled by server")
.asRuntimeException();
LOGGER.info("Sending cancelled by server error");
responseObserver.onError(error);
} else {
final CourierSummary courierSummary = CourierSummary
.newBuilder()
.setNumMessages(numMessages)
.setTotalLength(totalLength)
.build();
LOGGER.info("Sending CourierSummary for id=" + request.getId());
responseObserver.onNext(courierSummary);
}

if (Variant.EXIT_POST_RESPONSE.equals(request.getVariant())) {
sleep(1000);
LOGGER.info("Exiting server abruptly");
System.exit(1);
}
}

@Override
public void onError(final Throwable throwable) {
LOGGER.severe("Error in collecting packages" + throwable.getMessage());
responseObserver.onError(throwable);
public void onError(final Throwable th) {
LOGGER.severe("Error in collecting packages: " + th.getMessage());
responseObserver.onError(th);
}

@Override
@@ -123,6 +179,63 @@ public void onCompleted() {
}
};
}

@Override
public StreamObserver<CourierRequest> aggregatePackages(final StreamObserver<CourierSummary> responseObserver) {

if (responseObserver instanceof ServerCallStreamObserver) {
((ServerCallStreamObserver) responseObserver).setOnCancelHandler(() -> {
LOGGER.info("CancelHandler:collectPackages() was cancelled");
});
}
return new StreamObserver<CourierRequest>() {

private long numMessages = 0;
private long totalLength = 0;

@Override
public void onNext(final CourierRequest request) {
LOGGER.info("Received CourierRequest id=" + request.getId());

numMessages += 1;
totalLength += request.getMessage().length();
LOGGER.info("Summary of collected packages: numMessages=" + numMessages +
" with totalLength=" + totalLength);

if (Variant.EXIT_PRE_RESPONSE.equals(request.getVariant()) || Variant.EXIT_POST_RESPONSE.equals(request.getVariant())) {
sleep(1000);
LOGGER.info("Exiting server abruptly");
System.exit(1);
} else if (Variant.SEND_ERROR.equals(request.getVariant())) {
final StatusRuntimeException error = Status.CANCELLED
.withCause(new RuntimeException(request.getId()))
.withDescription("Cancelled by server")
.asRuntimeException();
LOGGER.info("Sending cancelled by server error");
responseObserver.onError(error);
}
}

@Override
public void onError(final Throwable th) {
LOGGER.severe("Error in aggregating packages: " + th.getMessage());
responseObserver.onError(th);
}

@Override
public void onCompleted() {
LOGGER.severe("Completed aggregating packages");
final CourierSummary courierSummary = CourierSummary
.newBuilder()
.setNumMessages(numMessages)
.setTotalLength(totalLength)
.build();
LOGGER.info("Sending aggregated CourierSummary");
responseObserver.onNext(courierSummary);
responseObserver.onCompleted();
}
};
}
}

private static class GrpcServerInterceptor implements ServerInterceptor {
@@ -142,20 +255,59 @@ public void onCompleted() {
new ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT>(serverCall) {
@Override
public void sendHeaders(final Metadata responseHeaders) {
LOGGER.info("GrpcServerInterceptor.sendHeaders[cid=" + correlationId + "]");
logMetadata(requestMetadata, "response");
if (correlationId != null) {
LOGGER.info("response linked to cid: " + correlationId);
responseHeaders.put(xCidKey, correlationId);
}
super.sendHeaders(responseHeaders);
}

@Override
public void sendMessage(final RespT response) {
LOGGER.info("GrpcServerInterceptor.sendMessage[cid=" + correlationId + "]");
super.sendMessage(response);
}

@Override
public void close(final Status status, final Metadata trailers) {
LOGGER.info("GrpcServerInterceptor.close[cid=" + correlationId + "] " + status + ", " + trailers);
super.close(status, trailers);
}
};
return serverCallHandler.startCall(wrapperCall, requestMetadata);
final ServerCall.Listener<ReqT> listener = serverCallHandler.startCall(wrapperCall, requestMetadata);
return new ServerCall.Listener<ReqT>() {
public void onMessage(final ReqT message) {
LOGGER.info("GrpcServerInterceptor.onMessage[cid=" + correlationId + "]");
listener.onMessage(message);
}

public void onHalfClose() {
LOGGER.info("GrpcServerInterceptor.onHalfClose[cid=" + correlationId + "]");
listener.onHalfClose();
}

public void onCancel() {
LOGGER.info("GrpcServerInterceptor.onCancel[cid=" + correlationId + "]");
listener.onCancel();
}

public void onComplete() {
LOGGER.info("GrpcServerInterceptor.onComplete[cid=" + correlationId + "]");
listener.onComplete();
}

public void onReady() {
LOGGER.info("GrpcServerInterceptor.onReady[cid=" + correlationId + "]");
listener.onReady();
}
};
}

private void logMetadata(final Metadata metadata, final String label) {
final Set<String> metadataKeys = metadata.keys();
LOGGER.info(label + " metadata keys = " + metadataKeys);
LOGGER.info(label + "@" + metadata.hashCode() + " metadata keys = " + metadataKeys);
for (final String key : metadataKeys) {
final String value = metadata.get(Metadata.Key.of(key, ASCII_STRING_MARSHALLER));
LOGGER.info(label + " metadata " + key + " = " + value);
@@ -22,29 +22,35 @@ option objc_class_prefix = "TSCP";

package courier;

// The greeting service definition.
service Courier {
// Sends a package.
rpc SendPackage (CourierRequest) returns (CourierReply);
// Processes a stream of messages
// Processes a stream of messages, returns a stream of messages
rpc CollectPackages (stream CourierRequest) returns (stream CourierSummary);
// Processes a stream of messages, returns a unary message
rpc AggregatePackages (stream CourierRequest) returns (CourierSummary);
}

enum Variant {
NORMAL = 0;
SEND_ERROR = 1;
EXIT_PRE_RESPONSE = 2;
EXIT_POST_RESPONSE = 3;
}

// The request message containing the package's name.
message CourierRequest {
string id = 1;
string from = 2;
string message = 3;
Variant variant = 4;
}

// The response message containing the package response.
message CourierReply {
string id = 1;
string message = 2;
string response = 3;
}

// The response message containing the package response.
message CourierSummary {
int64 num_messages = 1;
int64 total_length = 2;

0 comments on commit dd6bd41

Please sign in to comment.
You can’t perform that action at this time.