Skip to content
This repository has been archived by the owner on Mar 22, 2023. It is now read-only.

Commit

Permalink
adds support for propagating grpc server-side cancellations
Browse files Browse the repository at this point in the history
  • Loading branch information
shamsimam committed Jul 9, 2019
1 parent 5a0d19b commit 6dd5f97
Show file tree
Hide file tree
Showing 7 changed files with 502 additions and 93 deletions.
2 changes: 1 addition & 1 deletion containers/test-apps/courier/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<groupId>twosigma</groupId>
<artifactId>courier</artifactId>
<version>1.2.1</version>
<version>1.3.0</version>

<name>courier</name>
<url>https://github.com/twosigma/waiter/tree/master/test-apps/courier</url>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
Expand Down Expand Up @@ -148,6 +150,15 @@ public static CourierReply sendPackage(final String host,
} catch (final StatusRuntimeException e) {
logFunction.apply("RPC failed, status: " + e.getStatus());
return null;
} catch (ExecutionException e) {
final Status status = Status.fromThrowable(e.getCause());
logFunction.apply("RPC execution failed: " + status);
return CourierReply
.newBuilder()
.setId(status.getCode().toString())
.setMessage(status.getDescription())
.setResponse("error")
.build();
} catch (final Exception e) {
logFunction.apply("RPC failed, message: " + e.getMessage());
return null;
Expand All @@ -167,12 +178,14 @@ public static CourierReply sendPackage(final String host,
public static List<CourierSummary> collectPackages(final String host,
final int port,
final Map<String, Object> headers,
final String idPrefix,
final List<String> ids,
final String from,
final List<String> messages,
final int interMessageSleepMs,
final boolean lockStepMode) throws InterruptedException {
final boolean lockStepMode,
final int cancelThreshold) throws InterruptedException {
final ManagedChannel channel = initializeChannel(host, port);
final AtomicBoolean awaitChannelTermination = new AtomicBoolean(true);

try {
final Semaphore lockStep = new Semaphore(1);
Expand Down Expand Up @@ -214,6 +227,16 @@ public void onError(final Throwable throwable) {
logFunction.apply("releasing semaphore after receiving error");
lockStep.release();
}
if (throwable instanceof StatusRuntimeException) {
final StatusRuntimeException exception = (StatusRuntimeException) throwable;
final CourierSummary response = CourierSummary
.newBuilder()
.setNumMessages(0)
.setStatusCode(exception.getStatus().getCode().name())
.setStatusDescription(exception.getStatus().getDescription())
.build();
resultList.add(response);
}
}

@Override
Expand All @@ -229,11 +252,17 @@ private void resolveResponsePromise() {
});

for (int i = 0; i < messages.size(); i++) {
if (i >= cancelThreshold) {
logFunction.apply("cancelling sending messages");
awaitChannelTermination.set(false);
channel.shutdownNow();
throw new CancellationException("Cancel threshold reached: " + cancelThreshold);
}
if (errorSignal.get()) {
logFunction.apply("aborting sending messages as error was discovered");
break;
}
final String requestId = idPrefix + i;
final String requestId = ids.get(i);
if (lockStepMode) {
logFunction.apply("acquiring semaphore before sending request " + requestId);
lockStep.acquire();
Expand Down Expand Up @@ -264,39 +293,105 @@ private void resolveResponsePromise() {
}

} finally {
shutdownChannel(channel);
if (awaitChannelTermination.get()) {
shutdownChannel(channel);
}
}
}

public static List<CourierSummary> collectPackages(final String host,
final int port,
final Map<String, Object> headers,
final String idPrefix,
final String from,
final List<String> messages,
final int interMessageSleepMs,
final boolean lockStepMode,
final int cancelThreshold) throws InterruptedException {

final List<String> ids = new ArrayList<>(messages.size());
for (int i = 0; i < messages.size(); i++) {
ids.add(idPrefix + i);
}

return collectPackages(host, port, headers, ids, from, messages, interMessageSleepMs, lockStepMode, cancelThreshold);
}

/**
* Greet server. If provided, the first element of {@code args} is the name to use in the
* greeting.
*/
public static void main(final String... args) throws Exception {
/* Access a service running on the local machine on port 8080 */
final long startTimeMillis = System.currentTimeMillis();
final String host = "localhost";
final int port = 8080;
final HashMap<String, Object> headers = new HashMap<>();

final String id = UUID.randomUUID().toString();
final String user = "Jim";
final StringBuilder sb = new StringBuilder();
for (int i = 0; i < 100_000; i++) {
sb.append("a");
if (i % 1000 == 0) {
sb.append(".");
if (false) {
final String id = UUID.randomUUID().toString() + ".SEND_ERROR";
final String user = "Jim";
final StringBuilder sb = new StringBuilder();
for (int i = 0; i < 100_000; i++) {
sb.append("a");
if (i % 1000 == 0) {
sb.append(".");
}
}

headers.put("x-cid", "cid-send-package." + startTimeMillis);
final CourierReply courierReply = sendPackage(host, port, headers, id, user, sb.toString());
logFunction.apply("sendPackage response = " + courierReply);
}

if (false) {
headers.put("x-cid", "cid-collect-packages-complete." + startTimeMillis);
final List<String> messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList());
final List<CourierSummary> courierSummaries =
collectPackages(host, port, headers, "id-", "User", messages, 100, true, messages.size());
logFunction.apply("collectPackages response size = " + courierSummaries.size());
if (!courierSummaries.isEmpty()) {
final CourierSummary courierSummary = courierSummaries.get(courierSummaries.size() - 1);
logFunction.apply("collectPackages[complete] summary = " + courierSummary.toString());
}
}
final CourierReply courierReply = sendPackage(host, port, headers, id, user, sb.toString());
logFunction.apply("sendPackage response = " + courierReply);

final List<String> messages = IntStream.range(0, 100).mapToObj(i -> "message-" + i).collect(Collectors.toList());
final List<CourierSummary> courierSummaries =
collectPackages(host, port, headers, "id-", "User", messages, 10, true);
logFunction.apply("collectPackages response size = " + courierSummaries.size());
if (!courierSummaries.isEmpty()) {
final CourierSummary courierSummary = courierSummaries.get(courierSummaries.size() - 1);
logFunction.apply("collectPackages summary = " + courierSummary.toString());

if (false) {
headers.put("x-cid", "cid-collect-packages-cancel." + startTimeMillis);
final List<String> messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList());
final List<CourierSummary> courierSummaries =
collectPackages(host, port, headers, "id-", "User", messages, 100, true, messages.size() / 2);
logFunction.apply("collectPackages[cancel] summary = " + courierSummaries);
}

if (false) {
headers.put("x-cid", "cid-collect-packages-server-pre-cancel." + startTimeMillis);
final List<String> ids = IntStream.range(0, 10).mapToObj(i -> "id-" + i).collect(Collectors.toList());
ids.set(5, ids.get(5) + ".EXIT_PRE_RESPONSE");
final List<String> messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList());
final List<CourierSummary> courierSummaries =
collectPackages(host, port, headers, ids, "User", messages, 100, true, messages.size() + 1);
logFunction.apply("collectPackages[cancel] summary = " + courierSummaries);
}

if (false) {
headers.put("x-cid", "cid-collect-packages-server-post-cancel." + startTimeMillis);
final List<String> ids = IntStream.range(0, 10).mapToObj(i -> "id-" + i).collect(Collectors.toList());
ids.set(5, ids.get(5) + ".EXIT_POST_RESPONSE");
final List<String> messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList());
final List<CourierSummary> courierSummaries =
collectPackages(host, port, headers, ids, "User", messages, 100, true, messages.size() + 1);
logFunction.apply("collectPackages[cancel] summary = " + courierSummaries);
}

if (true) {
headers.put("x-cid", "cid-collect-packages-server-error." + startTimeMillis);
final List<String> ids = IntStream.range(0, 10).mapToObj(i -> "id-" + i).collect(Collectors.toList());
ids.set(5, ids.get(5) + ".SEND_ERROR");
final List<String> messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList());
final List<CourierSummary> courierSummaries =
collectPackages(host, port, headers, ids, "User", messages, 100, true, messages.size() + 1);
logFunction.apply("collectPackages[cancel] summary = " + courierSummaries);
}
}
}
Loading

0 comments on commit 6dd5f97

Please sign in to comment.