diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ingress.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ingress.kt index b2f67590b..c61bb3541 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ingress.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ingress.kt @@ -67,6 +67,18 @@ suspend fun Client.InvocationHandle.getOutputSuspend( return this.getOutputAsync(options).await() } +suspend fun Client.IdempotentInvocationHandle.attachSuspend( + options: RequestOptions = RequestOptions.DEFAULT +): T { + return this.attachAsync(options).await() +} + +suspend fun Client.IdempotentInvocationHandle.getOutputSuspend( + options: RequestOptions = RequestOptions.DEFAULT +): Output { + return this.getOutputAsync(options).await() +} + suspend fun Client.WorkflowHandle.attachSuspend( options: RequestOptions = RequestOptions.DEFAULT ): T { diff --git a/sdk-common/src/main/java/dev/restate/sdk/client/Client.java b/sdk-common/src/main/java/dev/restate/sdk/client/Client.java index 44a67a3ea..9d799571d 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/client/Client.java +++ b/sdk-common/src/main/java/dev/restate/sdk/client/Client.java @@ -211,6 +211,54 @@ default Output getOutput() throws IngressException { } } + IdempotentInvocationHandle idempotentInvocationHandle( + Target target, String idempotencyKey, Serde resSerde); + + interface IdempotentInvocationHandle { + + CompletableFuture attachAsync(RequestOptions options); + + default CompletableFuture attachAsync() { + return attachAsync(RequestOptions.DEFAULT); + } + + default Res attach(RequestOptions options) throws IngressException { + try { + return attachAsync(options).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + throw new RuntimeException(e.getCause()); + } + } + + default Res attach() throws IngressException { + return attach(RequestOptions.DEFAULT); + } + + CompletableFuture> getOutputAsync(RequestOptions options); + + default CompletableFuture> getOutputAsync() { + return getOutputAsync(RequestOptions.DEFAULT); + } + + default Output getOutput(RequestOptions options) throws IngressException { + try { + return getOutputAsync(options).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + throw new RuntimeException(e.getCause()); + } + } + + default Output getOutput() throws IngressException { + return getOutput(RequestOptions.DEFAULT); + } + } + WorkflowHandle workflowHandle( String workflowName, String workflowId, Serde resSerde); diff --git a/sdk-common/src/main/java/dev/restate/sdk/client/DefaultClient.java b/sdk-common/src/main/java/dev/restate/sdk/client/DefaultClient.java index 0710f2f22..bd4a73435 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/client/DefaultClient.java +++ b/sdk-common/src/main/java/dev/restate/sdk/client/DefaultClient.java @@ -26,6 +26,8 @@ import java.time.Duration; import java.util.*; import java.util.concurrent.CompletableFuture; +import java.util.function.BiFunction; +import org.jetbrains.annotations.NotNull; import org.jspecify.annotations.NonNull; public class DefaultClient implements Client { @@ -116,22 +118,30 @@ public CompletableFuture sendAsync( @Override public AwakeableHandle awakeableHandle(String id) { return new AwakeableHandle() { + private Void handleVoidResponse(HttpResponse response, Throwable throwable) { + if (throwable != null) { + throw new IngressException("Error when executing the request", throwable); + } + + if (response.statusCode() >= 300) { + handleNonSuccessResponse(response); + } + + return null; + } + @Override public CompletableFuture resolveAsync( Serde serde, @NonNull T payload, RequestOptions options) { // Prepare request var reqBuilder = - HttpRequest.newBuilder().uri(baseUri.resolve("/restate/awakeables/" + id + "/resolve")); + prepareBuilder(options).uri(baseUri.resolve("/restate/awakeables/" + id + "/resolve")); // Add content-type if (serde.contentType() != null) { reqBuilder.header("content-type", serde.contentType()); } - // Add headers - headers.forEach(reqBuilder::header); - options.getAdditionalHeaders().forEach(reqBuilder::header); - // Build and Send request HttpRequest request = reqBuilder @@ -139,18 +149,7 @@ public CompletableFuture resolveAsync( .build(); return httpClient .sendAsync(request, HttpResponse.BodyHandlers.ofByteArray()) - .handle( - (response, throwable) -> { - if (throwable != null) { - throw new IngressException("Error when executing the request", throwable); - } - - if (response.statusCode() >= 300) { - handleNonSuccessResponse(response); - } - - return null; - }); + .handle(this::handleVoidResponse); } @Override @@ -169,18 +168,7 @@ public CompletableFuture rejectAsync(String reason, RequestOptions options HttpRequest request = reqBuilder.POST(HttpRequest.BodyPublishers.ofString(reason)).build(); return httpClient .sendAsync(request, HttpResponse.BodyHandlers.ofByteArray()) - .handle( - (response, throwable) -> { - if (throwable != null) { - throw new IngressException("Error when executing the request", throwable); - } - - if (response.statusCode() >= 300) { - handleNonSuccessResponse(response); - } - - return null; - }); + .handle(this::handleVoidResponse); } }; } @@ -197,78 +185,72 @@ public String invocationId() { public CompletableFuture attachAsync(RequestOptions options) { // Prepare request var reqBuilder = - HttpRequest.newBuilder() + prepareBuilder(options) .uri(baseUri.resolve("/restate/invocation/" + invocationId + "/attach")); - // Add headers - headers.forEach(reqBuilder::header); - options.getAdditionalHeaders().forEach(reqBuilder::header); - // Build and Send request HttpRequest request = reqBuilder.GET().build(); return httpClient .sendAsync(request, HttpResponse.BodyHandlers.ofByteArray()) - .handle( - (response, throwable) -> { - if (throwable != null) { - throw new IngressException("Error when executing the request", throwable); - } - - if (response.statusCode() >= 300) { - handleNonSuccessResponse(response); - } - - try { - return resSerde.deserialize(response.body()); - } catch (Exception e) { - throw new IngressException( - "Cannot deserialize the response", - response.statusCode(), - response.body(), - e); - } - }); + .handle(handleAttachResponse(resSerde)); } @Override public CompletableFuture> getOutputAsync(RequestOptions options) { // Prepare request var reqBuilder = - HttpRequest.newBuilder() + prepareBuilder(options) .uri(baseUri.resolve("/restate/invocation/" + invocationId + "/output")); - // Add headers - headers.forEach(reqBuilder::header); - options.getAdditionalHeaders().forEach(reqBuilder::header); + // Build and Send request + HttpRequest request = reqBuilder.GET().build(); + return httpClient + .sendAsync(request, HttpResponse.BodyHandlers.ofByteArray()) + .handle(handleGetOutputResponse(resSerde)); + } + }; + } + + @Override + public IdempotentInvocationHandle idempotentInvocationHandle( + Target target, String idempotencyKey, Serde resSerde) { + return new IdempotentInvocationHandle<>() { + @Override + public CompletableFuture attachAsync(RequestOptions options) { + // Prepare request + var uri = + baseUri.resolve( + "/restate/invocation" + + targetToURI(target) + + "/" + + URLEncoder.encode(idempotencyKey, StandardCharsets.UTF_8) + + "/attach"); + var reqBuilder = prepareBuilder(options).uri(uri); + + // Build and Send request + HttpRequest request = reqBuilder.GET().build(); + return httpClient + .sendAsync(request, HttpResponse.BodyHandlers.ofByteArray()) + .handle(handleAttachResponse(resSerde)); + } + + @Override + public CompletableFuture> getOutputAsync(RequestOptions options) { + // Prepare request + var uri = + baseUri.resolve( + "/restate/invocation" + + targetToURI(target) + + "/" + + URLEncoder.encode(idempotencyKey, StandardCharsets.UTF_8) + + "/output"); + var reqBuilder = prepareBuilder(options).uri(uri); // Build and Send request HttpRequest request = reqBuilder.GET().build(); return httpClient .sendAsync(request, HttpResponse.BodyHandlers.ofByteArray()) - .handle( - (response, throwable) -> { - if (throwable != null) { - throw new IngressException("Error when executing the request", throwable); - } - - if (response.statusCode() == 470) { - return Output.notReady(); - } - - if (response.statusCode() >= 300) { - handleNonSuccessResponse(response); - } - - try { - return Output.ready(resSerde.deserialize(response.body())); - } catch (Exception e) { - throw new IngressException( - "Cannot deserialize the response", - response.statusCode(), - response.body(), - e); - } - }); + .handle(handleGetOutputResponse(resSerde)); } }; } @@ -281,7 +263,7 @@ public WorkflowHandle workflowHandle( public CompletableFuture attachAsync(RequestOptions options) { // Prepare request var reqBuilder = - HttpRequest.newBuilder() + prepareBuilder(options) .uri( baseUri.resolve( "/restate/workflow/" @@ -290,41 +272,18 @@ public CompletableFuture attachAsync(RequestOptions options) { + URLEncoder.encode(workflowId, StandardCharsets.UTF_8) + "/attach")); - // Add headers - headers.forEach(reqBuilder::header); - options.getAdditionalHeaders().forEach(reqBuilder::header); - // Build and Send request HttpRequest request = reqBuilder.GET().build(); return httpClient .sendAsync(request, HttpResponse.BodyHandlers.ofByteArray()) - .handle( - (response, throwable) -> { - if (throwable != null) { - throw new IngressException("Error when executing the request", throwable); - } - - if (response.statusCode() >= 300) { - handleNonSuccessResponse(response); - } - - try { - return resSerde.deserialize(response.body()); - } catch (Exception e) { - throw new IngressException( - "Cannot deserialize the response", - response.statusCode(), - response.body(), - e); - } - }); + .handle(handleAttachResponse(resSerde)); } @Override public CompletableFuture> getOutputAsync(RequestOptions options) { // Prepare request var reqBuilder = - HttpRequest.newBuilder() + prepareBuilder(options) .uri( baseUri.resolve( "/restate/workflow/" @@ -333,49 +292,73 @@ public CompletableFuture> getOutputAsync(RequestOptions options) { + URLEncoder.encode(workflowId, StandardCharsets.UTF_8) + "/output")); - // Add headers - headers.forEach(reqBuilder::header); - options.getAdditionalHeaders().forEach(reqBuilder::header); - // Build and Send request HttpRequest request = reqBuilder.GET().build(); return httpClient .sendAsync(request, HttpResponse.BodyHandlers.ofByteArray()) - .handle( - (response, throwable) -> { - if (throwable != null) { - throw new IngressException("Error when executing the request", throwable); - } - - if (response.statusCode() == 470) { - return Output.notReady(); - } - - if (response.statusCode() >= 300) { - handleNonSuccessResponse(response); - } - - try { - return Output.ready(resSerde.deserialize(response.body())); - } catch (Exception e) { - throw new IngressException( - "Cannot deserialize the response", - response.statusCode(), - response.body(), - e); - } - }); + .handle(handleGetOutputResponse(resSerde)); } }; } - private URI toRequestURI(Target target, boolean isSend, Duration delay) { + private @NotNull + BiFunction, Throwable, Output> handleGetOutputResponse( + Serde resSerde) { + return (response, throwable) -> { + if (throwable != null) { + throw new IngressException("Error when executing the request", throwable); + } + + if (response.statusCode() == 470) { + return Output.notReady(); + } + + if (response.statusCode() >= 300) { + handleNonSuccessResponse(response); + } + + try { + return Output.ready(resSerde.deserialize(response.body())); + } catch (Exception e) { + throw new IngressException( + "Cannot deserialize the response", response.statusCode(), response.body(), e); + } + }; + } + + private @NotNull BiFunction, Throwable, Res> handleAttachResponse( + Serde resSerde) { + return (response, throwable) -> { + if (throwable != null) { + throw new IngressException("Error when executing the request", throwable); + } + + if (response.statusCode() >= 300) { + handleNonSuccessResponse(response); + } + + try { + return resSerde.deserialize(response.body()); + } catch (Exception e) { + throw new IngressException( + "Cannot deserialize the response", response.statusCode(), response.body(), e); + } + }; + } + + /** Contains prefix / but not postfix / */ + private String targetToURI(Target target) { StringBuilder builder = new StringBuilder(); builder.append("/").append(target.getService()); if (target.getKey() != null) { builder.append("/").append(URLEncoder.encode(target.getKey(), StandardCharsets.UTF_8)); } builder.append("/").append(target.getHandler()); + return builder.toString(); + } + + private URI toRequestURI(Target target, boolean isSend, Duration delay) { + StringBuilder builder = new StringBuilder(targetToURI(target)); if (isSend) { builder.append("/send"); } @@ -386,19 +369,8 @@ private URI toRequestURI(Target target, boolean isSend, Duration delay) { return this.baseUri.resolve(builder.toString()); } - private HttpRequest prepareHttpRequest( - Target target, - boolean isSend, - Serde reqSerde, - Req req, - Duration delay, - RequestOptions options) { - var reqBuilder = HttpRequest.newBuilder().uri(toRequestURI(target, isSend, delay)); - - // Add content-type - if (reqSerde.contentType() != null) { - reqBuilder.header("content-type", reqSerde.contentType()); - } + private HttpRequest.Builder prepareBuilder(RequestOptions options) { + var reqBuilder = HttpRequest.newBuilder(); // Add headers this.headers.forEach(reqBuilder::header); @@ -413,6 +385,23 @@ private HttpRequest prepareHttpRequest( // Add additional headers options.getAdditionalHeaders().forEach(reqBuilder::header); + return reqBuilder; + } + + private HttpRequest prepareHttpRequest( + Target target, + boolean isSend, + Serde reqSerde, + Req req, + Duration delay, + RequestOptions options) { + var reqBuilder = prepareBuilder(options).uri(toRequestURI(target, isSend, delay)); + + // Add content-type + if (reqSerde.contentType() != null) { + reqBuilder.header("content-type", reqSerde.contentType()); + } + return reqBuilder.POST(HttpRequest.BodyPublishers.ofByteArray(reqSerde.serialize(req))).build(); }