Skip to content

Commit

Permalink
Make the ingress client async. Fix #247 (#252)
Browse files Browse the repository at this point in the history
  • Loading branch information
slinkydeveloper committed Mar 20, 2024
1 parent 3c69539 commit 4205690
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 57 deletions.
29 changes: 29 additions & 0 deletions sdk-api-gen/src/main/resources/templates/Client.hbs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,21 @@ public class {{generatedClassSimpleName}} {
{{outputSerdeFieldName}},
{{#if inputEmpty}}null{{else}}req{{/if}},
requestOptions);
}

public {{#if outputEmpty}}java.util.concurrent.CompletableFuture<Void>{{else}}java.util.concurrent.CompletableFuture<{{{boxedOutputFqcn}}}>{{/if}} {{name}}Async({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) {
return this.{{name}}Async(
{{^inputEmpty}}req, {{/inputEmpty}}
dev.restate.sdk.client.RequestOptions.DEFAULT);
}

public {{#if outputEmpty}}java.util.concurrent.CompletableFuture<Void>{{else}}java.util.concurrent.CompletableFuture<{{{boxedOutputFqcn}}}>{{/if}} {{name}}Async({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}dev.restate.sdk.client.RequestOptions requestOptions) {
return this.ingressClient.callAsync(
{{#if isObject}}Target.virtualObject(COMPONENT_NAME, this.key, "{{name}}"){{else}}Target.service(COMPONENT_NAME, "{{name}}"){{/if}},
{{inputSerdeFieldName}},
{{outputSerdeFieldName}},
{{#if inputEmpty}}null{{else}}req{{/if}},
requestOptions);
}{{/handlers}}

public Send send() {
Expand All @@ -129,6 +144,20 @@ public class {{generatedClassSimpleName}} {
{{inputSerdeFieldName}},
{{#if inputEmpty}}null{{else}}req{{/if}},
requestOptions);
}

public java.util.concurrent.CompletableFuture<String> {{name}}Async({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) {
return this.{{name}}Async(
{{^inputEmpty}}req, {{/inputEmpty}}
dev.restate.sdk.client.RequestOptions.DEFAULT);
}

public java.util.concurrent.CompletableFuture<String> {{name}}Async({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}dev.restate.sdk.client.RequestOptions requestOptions) {
return IngressClient.this.ingressClient.sendAsync(
{{#if isObject}}Target.virtualObject(COMPONENT_NAME, IngressClient.this.key, "{{name}}"){{else}}Target.service(COMPONENT_NAME, "{{name}}"){{/if}},
{{inputSerdeFieldName}},
{{#if inputEmpty}}null{{else}}req{{/if}},
requestOptions);
}{{/handlers}}
}
}
Expand Down
1 change: 1 addition & 0 deletions sdk-api-kotlin-gen/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies {
testImplementation(testingLibs.assertj)
testImplementation(coreLibs.protobuf.java)
testImplementation(coreLibs.log4j.core)
testImplementation(kotlinLibs.kotlinx.coroutines)

// Import test suites from sdk-core
testImplementation(project(":sdk-core", "testArchive"))
Expand Down
9 changes: 5 additions & 4 deletions sdk-api-kotlin-gen/src/main/resources/templates/Client.hbs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import dev.restate.sdk.common.StateKey
import dev.restate.sdk.common.Serde
import dev.restate.sdk.common.Target
import kotlin.time.Duration
import kotlinx.coroutines.future.await

object {{generatedClassSimpleName}} {

Expand Down Expand Up @@ -73,12 +74,12 @@ object {{generatedClassSimpleName}} {

{{#handlers}}
suspend fun {{name}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}requestOptions: dev.restate.sdk.client.RequestOptions = dev.restate.sdk.client.RequestOptions.DEFAULT): {{{boxedOutputFqcn}}} {
return this.ingressClient.call(
return this.ingressClient.callAsync(
{{#if isObject}}Target.virtualObject(COMPONENT_NAME, this.key, "{{name}}"){{else}}Target.service(COMPONENT_NAME, "{{name}}"){{/if}},
{{inputSerdeFieldName}},
{{outputSerdeFieldName}},
{{#if inputEmpty}}null{{else}}req{{/if}},
requestOptions);
requestOptions).await();
}{{/handlers}}

fun send(): Send {
Expand All @@ -88,11 +89,11 @@ object {{generatedClassSimpleName}} {
inner class Send {
{{#handlers}}
suspend fun {{name}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}requestOptions: dev.restate.sdk.client.RequestOptions = dev.restate.sdk.client.RequestOptions.DEFAULT): String {
return this@IngressClient.ingressClient.send(
return this@IngressClient.ingressClient.sendAsync(
{{#if isObject}}Target.virtualObject(COMPONENT_NAME, this@IngressClient.key, "{{name}}"){{else}}Target.service(COMPONENT_NAME, "{{name}}"){{/if}},
{{inputSerdeFieldName}},
{{#if inputEmpty}}null{{else}}req{{/if}},
requestOptions);
requestOptions).await();
}{{/handlers}}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
import com.fasterxml.jackson.core.JsonToken;
import dev.restate.sdk.common.Serde;
import dev.restate.sdk.common.Target;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

public class DefaultIngressClient implements IngressClient {

Expand All @@ -37,53 +38,58 @@ public DefaultIngressClient(HttpClient httpClient, String baseUri, Map<String, S
}

@Override
public <Req, Res> Res call(
public <Req, Res> CompletableFuture<Res> callAsync(
Target target,
Serde<Req> reqSerde,
Serde<Res> resSerde,
Req req,
RequestOptions requestOptions) {
HttpRequest request = prepareHttpRequest(target, false, reqSerde, req, requestOptions);
HttpResponse<byte[]> response;
try {
response = httpClient.send(request, HttpResponse.BodyHandlers.ofByteArray());
} catch (IOException | InterruptedException e) {
throw new RuntimeException("Error when executing the request", e);
}

if (response.statusCode() != 200) {
// Try to parse as string
String error = new String(response.body(), StandardCharsets.UTF_8);
throw new RuntimeException(
"Received non OK status code: " + response.statusCode() + ". Body: " + error);
}

return resSerde.deserialize(response.body());
return httpClient
.sendAsync(request, HttpResponse.BodyHandlers.ofByteArray())
.handle(
(response, throwable) -> {
if (throwable != null) {
throw new IngressException("Error when executing the request", throwable);
}

if (response.statusCode() >= 300) {
handleNonSuccessResponse(response);
}

try {
return resSerde.deserialize(response.body());
} catch (Exception e) {
throw new IngressException(
"Cannot deserialize the response", response.statusCode(), response.body(), e);
}
});
}

@Override
public <Req> String send(Target target, Serde<Req> reqSerde, Req req, RequestOptions options) {
public <Req> CompletableFuture<String> sendAsync(
Target target, Serde<Req> reqSerde, Req req, RequestOptions options) {
HttpRequest request = prepareHttpRequest(target, true, reqSerde, req, options);
HttpResponse<InputStream> response;
try {
response = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
} catch (IOException | InterruptedException e) {
throw new RuntimeException("Error when executing the request", e);
}

try (InputStream in = response.body()) {
if (response.statusCode() >= 300) {
// Try to parse as string
String error = new String(in.readAllBytes(), StandardCharsets.UTF_8);
throw new RuntimeException(
"Received non OK status code: " + response.statusCode() + ". Body: " + error);
}
return deserializeInvocationId(in);
} catch (IOException e) {
throw new RuntimeException(
"Error when trying to read the response, when status code was " + response.statusCode(),
e);
}
return httpClient
.sendAsync(request, HttpResponse.BodyHandlers.ofByteArray())
.handle(
(response, throwable) -> {
if (throwable != null) {
throw new IngressException("Error when executing the request", throwable);
}

if (response.statusCode() >= 300) {
handleNonSuccessResponse(response);
}

try {
return findStringFieldInJsonObject(
new ByteArrayInputStream(response.body()), "invocationId");
} catch (Exception e) {
throw new IngressException(
"Cannot deserialize the response", response.statusCode(), response.body(), e);
}
});
}

private URI toRequestURI(Target target, boolean isSend) {
Expand Down Expand Up @@ -128,23 +134,43 @@ private <Req> HttpRequest prepareHttpRequest(
return reqBuilder.POST(HttpRequest.BodyPublishers.ofByteArray(reqSerde.serialize(req))).build();
}

private static String deserializeInvocationId(InputStream body) throws IOException {
private void handleNonSuccessResponse(HttpResponse<byte[]> response) {
if (response.headers().firstValue("content-type").orElse("").contains("application/json")) {
String errorMessage;
// Let's try to parse the message field
try {
errorMessage =
findStringFieldInJsonObject(new ByteArrayInputStream(response.body()), "message");
} catch (Exception e) {
throw new IngressException(
"Can't decode error response from ingress", response.statusCode(), response.body(), e);
}
throw new IngressException(errorMessage, response.statusCode(), response.body());
}

// Fallback error
throw new IngressException(
"Received non success status code", response.statusCode(), response.body());
}

private static String findStringFieldInJsonObject(InputStream body, String fieldName)
throws IOException {
try (JsonParser parser = JSON_FACTORY.createParser(body)) {
if (parser.nextToken() != JsonToken.START_OBJECT) {
throw new IllegalStateException(
"Expecting token " + JsonToken.START_OBJECT + ", got " + parser.getCurrentToken());
}
String fieldName = parser.nextFieldName();
if (fieldName == null || !fieldName.equalsIgnoreCase("invocationid")) {
throw new IllegalStateException(
"Expecting token \"invocationId\", got " + parser.getCurrentToken());
}
String invocationId = parser.nextTextValue();
if (invocationId == null) {
throw new IllegalStateException(
"Expecting token " + JsonToken.VALUE_STRING + ", got " + parser.getCurrentToken());
for (String actualFieldName = parser.nextFieldName();
actualFieldName != null;
actualFieldName = parser.nextFieldName()) {
if (actualFieldName.equalsIgnoreCase(fieldName)) {
return parser.nextTextValue();
} else {
parser.nextValue();
}
}
return invocationId;
throw new IllegalStateException(
"Expecting field name \"" + fieldName + "\", got " + parser.getCurrentToken());
}
}
}
47 changes: 43 additions & 4 deletions sdk-common/src/main/java/dev/restate/sdk/client/IngressClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,57 @@
import java.net.http.HttpClient;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;

public interface IngressClient {
<Req, Res> Res call(

<Req, Res> CompletableFuture<Res> callAsync(
Target target, Serde<Req> reqSerde, Serde<Res> resSerde, Req req, RequestOptions options);

default <Req, Res> Res call(Target target, Serde<Req> reqSerde, Serde<Res> resSerde, Req req) {
default <Req, Res> CompletableFuture<Res> callAsync(
Target target, Serde<Req> reqSerde, Serde<Res> resSerde, Req req) {
return callAsync(target, reqSerde, resSerde, req, RequestOptions.DEFAULT);
}

default <Req, Res> Res call(
Target target, Serde<Req> reqSerde, Serde<Res> resSerde, Req req, RequestOptions options)
throws IngressException {
try {
return callAsync(target, reqSerde, resSerde, req, options).join();
} catch (CompletionException e) {
if (e.getCause() instanceof RuntimeException) {
throw (RuntimeException) e.getCause();
}
throw new RuntimeException(e.getCause());
}
}

default <Req, Res> Res call(Target target, Serde<Req> reqSerde, Serde<Res> resSerde, Req req)
throws IngressException {
return call(target, reqSerde, resSerde, req, RequestOptions.DEFAULT);
}

<Req> String send(Target target, Serde<Req> reqSerde, Req req, RequestOptions options);
<Req> CompletableFuture<String> sendAsync(
Target target, Serde<Req> reqSerde, Req req, RequestOptions options);

default <Req> CompletableFuture<String> sendAsync(Target target, Serde<Req> reqSerde, Req req) {
return sendAsync(target, reqSerde, req, RequestOptions.DEFAULT);
}

default <Req> String send(Target target, Serde<Req> reqSerde, Req req, RequestOptions options)
throws IngressException {
try {
return sendAsync(target, reqSerde, req, options).join();
} catch (CompletionException e) {
if (e.getCause() instanceof RuntimeException) {
throw (RuntimeException) e.getCause();
}
throw new RuntimeException(e.getCause());
}
}

default <Req> String send(Target target, Serde<Req> reqSerde, Req req) {
default <Req> String send(Target target, Serde<Req> reqSerde, Req req) throws IngressException {
return send(target, reqSerde, req, RequestOptions.DEFAULT);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH
//
// This file is part of the Restate Java SDK,
// which is released under the MIT license.
//
// You can find a copy of the license in file LICENSE in the root
// directory of this repository or package, or at
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
package dev.restate.sdk.client;

import java.nio.charset.StandardCharsets;
import org.jspecify.annotations.Nullable;

public class IngressException extends RuntimeException {

private final int statusCode;
private final byte[] responseBody;

public IngressException(String message, Throwable cause) {
this(message, -1, null, cause);
}

public IngressException(String message, int statusCode, byte[] responseBody) {
this(message, statusCode, responseBody, null);
}

public IngressException(String message, int statusCode, byte[] responseBody, Throwable cause) {
super(message, cause);
this.statusCode = statusCode;
this.responseBody = responseBody;
}

public int getStatusCode() {
return statusCode;
}

public byte @Nullable [] getResponseBody() {
return responseBody;
}

@Override
public String toString() {
return "IngressException{"
+ "statusCode="
+ statusCode
+ ", responseBody='"
+ new String(responseBody, StandardCharsets.UTF_8)
+ '\''
+ ", message='"
+ this.getMessage()
+ '\''
+ '}';
}
}

0 comments on commit 4205690

Please sign in to comment.