Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Now send acks on combinators #334

Merged
merged 1 commit into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk-api/src/test/java/dev/restate/sdk/DeferredTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ protected TestInvocationBuilder awaitOnAlreadyResolvedAwaitables() {

protected TestInvocationBuilder awaitWithTimeout() {
return testDefinitionForService(
"AwaitOnAlreadyResolvedAwaitables",
"AwaitWithTimeout",
Serde.VOID,
JsonSerdes.STRING,
(ctx, unused) -> {
Expand Down
56 changes: 56 additions & 0 deletions sdk-core/src/main/java/dev/restate/sdk/core/AckStateMachine.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH
//
// This file is part of the Restate Java SDK,
// which is released under the MIT license.
//
// You can find a copy of the license in file LICENSE in the root
// directory of this repository or package, or at
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
package dev.restate.sdk.core;

/** State machine tracking acks */
class AckStateMachine extends BaseSuspendableCallbackStateMachine<AckStateMachine.AckCallback> {

interface AckCallback extends SuspendableCallback {
void onAck();
}

private int lastAcknowledgedEntry = -1;

/** -1 means no side effect waiting to be acked. */
private int lastEntryToAck = -1;

void waitLastAck(AckCallback callback) {
if (lastEntryIsAcked()) {
callback.onAck();
} else {
this.setCallback(callback);
}
}

void tryHandleAck(int entryIndex) {
this.lastAcknowledgedEntry = Math.max(entryIndex, this.lastAcknowledgedEntry);
if (lastEntryIsAcked()) {
this.consumeCallback(AckCallback::onAck);
}
}

void registerEntryToAck(int entryIndex) {
this.lastEntryToAck = Math.max(entryIndex, this.lastEntryToAck);
}

private boolean lastEntryIsAcked() {
return this.lastEntryToAck <= this.lastAcknowledgedEntry;
}

public int getLastEntryToAck() {
return lastEntryToAck;
}

@Override
void abort(Throwable cause) {
super.abort(cause);
// We can't do anything else if the input stream is closed, so we just fail the callback, if any
this.tryFailCallback();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class InvocationStateMachine implements InvocationFlow.InvocationProcessor {

// Buffering of messages and completions
private final IncomingEntriesStateMachine incomingEntriesStateMachine;
private final SideEffectAckStateMachine sideEffectAckStateMachine;
private final AckStateMachine ackStateMachine;
private final ReadyResultStateMachine readyResultStateMachine;

// Flow sub/pub
Expand All @@ -75,7 +75,7 @@ class InvocationStateMachine implements InvocationFlow.InvocationProcessor {

this.incomingEntriesStateMachine = new IncomingEntriesStateMachine();
this.readyResultStateMachine = new ReadyResultStateMachine();
this.sideEffectAckStateMachine = new SideEffectAckStateMachine();
this.ackStateMachine = new AckStateMachine();

this.afterStartCallback = new CallbackHandle<>();
}
Expand Down Expand Up @@ -142,8 +142,7 @@ public void onNext(InvocationFlow.InvocationInput invocationInput) {
// runtime.
this.readyResultStateMachine.offerCompletion((Protocol.CompletionMessage) msg);
} else if (msg instanceof Protocol.EntryAckMessage) {
this.sideEffectAckStateMachine.tryHandleSideEffectAck(
((Protocol.EntryAckMessage) msg).getEntryIndex());
this.ackStateMachine.tryHandleAck(((Protocol.EntryAckMessage) msg).getEntryIndex());
} else {
this.incomingEntriesStateMachine.offer(msg);
}
Expand All @@ -159,7 +158,7 @@ public void onError(Throwable throwable) {
public void onComplete() {
LOG.trace("Input publisher closed");
this.readyResultStateMachine.abort(AbortedExecutionException.INSTANCE);
this.sideEffectAckStateMachine.abort(AbortedExecutionException.INSTANCE);
this.ackStateMachine.abort(AbortedExecutionException.INSTANCE);
}

// --- Init routine to wait for the start message
Expand Down Expand Up @@ -287,7 +286,7 @@ private void closeWithMessage(MessageLite closeMessage, Throwable cause) {
// Unblock any eventual waiting callbacks
this.afterStartCallback.consume(cb -> cb.onCancel(cause));
this.readyResultStateMachine.abort(cause);
this.sideEffectAckStateMachine.abort(cause);
this.ackStateMachine.abort(cause);
this.incomingEntriesStateMachine.abort(cause);
this.span.end();
}
Expand Down Expand Up @@ -456,21 +455,21 @@ void exitSideEffectBlock(
}

// Write new entry
this.sideEffectAckStateMachine.registerExecutedSideEffect(this.currentJournalEntryIndex);
this.ackStateMachine.registerEntryToAck(this.currentJournalEntryIndex);
this.writeEntry(sideEffectEntry);

// Wait for entry to be acked
Protocol.RunEntryMessage finalSideEffectEntry = sideEffectEntry;
this.sideEffectAckStateMachine.waitLastSideEffectAck(
new SideEffectAckStateMachine.SideEffectAckCallback() {
this.ackStateMachine.waitLastAck(
new AckStateMachine.AckCallback() {
@Override
public void onLastSideEffectAck() {
public void onAck() {
completeSideEffectCallbackWithEntry(finalSideEffectEntry, callback);
}

@Override
public void onSuspend() {
suspend(List.of(sideEffectAckStateMachine.getLastExecutedSideEffect()));
suspend(List.of(ackStateMachine.getLastEntryToAck()));
callback.onCancel(AbortedExecutionException.INSTANCE);
}

Expand Down Expand Up @@ -621,8 +620,7 @@ private void resolveCombinatorDeferred(
+ "This is a symptom of an SDK bug, please contact the developers.");
}

writeCombinatorEntry(Collections.emptyList());
callback.onSuccess(null);
writeCombinatorEntry(Collections.emptyList(), callback);
return;
}

Expand All @@ -636,8 +634,7 @@ private void resolveCombinatorDeferred(

// Try to resolve the combinator now
if (rootDeferred.tryResolve(entryIndex)) {
writeCombinatorEntry(resolvedOrder);
callback.onSuccess(null);
writeCombinatorEntry(resolvedOrder, callback);
return;
}
} else {
Expand Down Expand Up @@ -667,8 +664,7 @@ public boolean onNewResult(Map<Integer, Result<?>> resultMap) {

// Try to resolve the combinator now
if (rootDeferred.tryResolve(entryIndex)) {
writeCombinatorEntry(resolvedOrder);
callback.onSuccess(null);
writeCombinatorEntry(resolvedOrder, callback);
return true;
}
}
Expand All @@ -694,12 +690,35 @@ public void onError(Throwable e) {
}
}

private void writeCombinatorEntry(List<Integer> resolvedList) {
private void writeCombinatorEntry(List<Integer> resolvedList, SyscallCallback<Void> callback) {
// Create and write the entry
Java.CombinatorAwaitableEntryMessage entry =
Java.CombinatorAwaitableEntryMessage.newBuilder().addAllEntryIndex(resolvedList).build();
span.addEvent("Combinator");

// We register the combinator entry to wait for an ack
this.ackStateMachine.registerEntryToAck(this.currentJournalEntryIndex);
writeEntry(entry);

// Let's wait for the ack
this.ackStateMachine.waitLastAck(
new AckStateMachine.AckCallback() {
@Override
public void onAck() {
callback.onSuccess(null);
}

@Override
public void onSuspend() {
suspend(List.of(ackStateMachine.getLastEntryToAck()));
callback.onCancel(AbortedExecutionException.INSTANCE);
}

@Override
public void onError(Throwable e) {
callback.onCancel(e);
}
});
}

// --- Internal callback
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package dev.restate.sdk.core;

import com.google.protobuf.MessageLite;
import dev.restate.generated.sdk.java.Java;
import dev.restate.generated.service.protocol.Protocol;

public class MessageHeader {
Expand Down Expand Up @@ -82,6 +83,9 @@ public static MessageHeader fromMessage(MessageLite msg) {
} else if (msg instanceof Protocol.RunEntryMessage) {
return new MessageHeader(
MessageType.RunEntryMessage, REQUIRES_ACK_FLAG, msg.getSerializedSize());
} else if (msg instanceof Java.CombinatorAwaitableEntryMessage) {
return new MessageHeader(
MessageType.CombinatorAwaitableEntryMessage, REQUIRES_ACK_FLAG, msg.getSerializedSize());
}
// Messages with no flags
return new MessageHeader(MessageType.fromMessage(msg), 0, msg.getSerializedSize());
Expand Down

This file was deleted.

Loading
Loading