Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ import {
} from "./vars";
import {
workflowAccessActor,
workflowBatchJoinActor,
workflowCounterActor,
workflowQueueActor,
workflowSleepActor,
Expand Down Expand Up @@ -166,6 +167,7 @@ export const registry = setup({
workflowAccessActor,
workflowSleepActor,
workflowStopTeardownActor,
workflowBatchJoinActor,
// From actor-db-raw.ts
dbActorRaw,
// From actor-db-drizzle.ts
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,64 @@ async function updateWorkflowAccessState(
typeof client.workflowQueueActor.getForId === "function";
}

const WORKFLOW_BATCH_QUEUE_NAME = "batch-requests";

export const workflowBatchJoinActor = actor({
state: {
processedRows: [] as number[],
processedCells: [] as string[],
requestsCompleted: 0,
},
queues: {
[WORKFLOW_BATCH_QUEUE_NAME]: queue<{ rowIds: number[] }>(),
},
run: workflow(async (ctx) => {
await ctx.loop("request-loop", async (loopCtx) => {
const request = await loopCtx.queue.next("wait-request", {
names: [WORKFLOW_BATCH_QUEUE_NAME],
});

const rowIds = request.body.rowIds;

// Fan out all rows in a single join.
const branches = Object.fromEntries(
rowIds.map((rowId, i) => [
`row-${i}`,
{
run: async (branchCtx: WorkflowLoopContextOf<typeof workflowBatchJoinActor>) => {
await branchCtx.step(`cell-a-${rowId}`, async () => {
branchCtx.state.processedCells.push(`a-${rowId}`);
});
await branchCtx.step(`cell-b-${rowId}`, async () => {
branchCtx.state.processedCells.push(`b-${rowId}`);
});
await branchCtx.step(`cell-c-${rowId}`, async () => {
branchCtx.state.processedRows.push(rowId);
branchCtx.state.processedCells.push(`c-${rowId}`);
});
},
},
]),
);

await loopCtx.join("process-rows", branches);

await loopCtx.step("request-done", async () => {
loopCtx.state.requestsCompleted += 1;
});

return Loop.continue(undefined);
});
}),
actions: {
getState: (c) => c.state,
},
});

function incrementWorkflowSleepTick(
ctx: WorkflowLoopContextOf<typeof workflowSleepActor>,
): void {
ctx.state.ticks += 1;
}

export { WORKFLOW_QUEUE_NAME };
export { WORKFLOW_QUEUE_NAME, WORKFLOW_BATCH_QUEUE_NAME };
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { describe, expect, test } from "vitest";
import {
WORKFLOW_BATCH_QUEUE_NAME,
WORKFLOW_QUEUE_NAME,
} from "../../../fixtures/driver-test-suite/workflow";
import type { DriverTestConfig } from "../mod";
Expand Down Expand Up @@ -110,6 +111,54 @@ export function runActorWorkflowTests(driverTestConfig: DriverTestConfig) {
},
);

test("join fans out rows in parallel inside loop", async (c) => {
const { client } = await setupDriverTest(c, driverTestConfig);
const actor = client.workflowBatchJoinActor.getOrCreate([
"workflow-batch-join",
]);

await actor.send(WORKFLOW_BATCH_QUEUE_NAME, {
rowIds: [1, 2, 3, 4],
});

let state = await actor.getState();
for (let i = 0; i < 50; i++) {
if (state.requestsCompleted >= 1) break;
await waitFor(driverTestConfig, 100);
state = await actor.getState();
}

expect(state.requestsCompleted).toBe(1);
expect(state.processedRows.sort()).toEqual([1, 2, 3, 4]);
expect(state.processedCells.length).toBe(12);
});

test("join handles sequential queue requests", async (c) => {
const { client } = await setupDriverTest(c, driverTestConfig);
const actor = client.workflowBatchJoinActor.getOrCreate([
"workflow-batch-join-sequential",
]);

await actor.send(WORKFLOW_BATCH_QUEUE_NAME, {
rowIds: [1, 2],
});
await actor.send(WORKFLOW_BATCH_QUEUE_NAME, {
rowIds: [3, 4],
});

let state = await actor.getState();
for (let i = 0; i < 50; i++) {
if (state.requestsCompleted >= 2) break;
await waitFor(driverTestConfig, 100);
state = await actor.getState();
}

expect(state.requestsCompleted).toBeGreaterThanOrEqual(2);
expect(state.processedRows).toEqual(
expect.arrayContaining([1, 2, 3, 4]),
);
});

// NOTE: Test for workflow persistence across actor sleep is complex because
// calling c.sleep() during a workflow prevents clean shutdown. The workflow
// persistence is implicitly tested by the "sleeps and resumes between ticks"
Expand Down
41 changes: 36 additions & 5 deletions rivetkit-typescript/packages/workflow-engine/src/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,31 @@ export class WorkflowContextImpl implements WorkflowContextInterface {
this.visitedKeys.add(key);
}

/**
* Merge visited keys from a child branch context into this context.
* This ensures that entries validated by nested branches are also
* recognized as visited by the parent scope's validateComplete.
*/
mergeVisitedKeys(child: WorkflowContextImpl): void {
for (const key of child.visitedKeys) {
this.visitedKeys.add(key);
}
}

/**
* Mark all history entries under a location prefix as visited.
* Used when replaying completed branches that are skipped during
* re-execution so their child entries don't trigger validateComplete errors.
*/
private markAllEntriesVisited(location: Location): void {
const prefix = locationToKey(this.storage, location);
for (const key of this.storage.history.entries.keys()) {
if (key.startsWith(prefix + "/") || key === prefix) {
this.visitedKeys.add(key);
}
}
}

/**
* Check if a name has already been used at the current location in this execution.
* Throws HistoryDivergedError if duplicate detected.
Expand Down Expand Up @@ -737,6 +762,7 @@ export class WorkflowContextImpl implements WorkflowContextInterface {

// Validate branch completed cleanly
branchCtx.validateComplete();
this.mergeVisitedKeys(branchCtx);

if ("break" in result && result.break) {
// Loop complete
Expand Down Expand Up @@ -1490,24 +1516,27 @@ export class WorkflowContextImpl implements WorkflowContextInterface {
async ([branchName, config]) => {
const branchStatus = joinData.branches[branchName];

const branchLocation = appendName(
this.storage,
location,
branchName,
);

// Already completed
if (branchStatus.status === "completed") {
this.markAllEntriesVisited(branchLocation);
results[branchName] = branchStatus.output;
return;
}

// Already failed
if (branchStatus.status === "failed") {
this.markAllEntriesVisited(branchLocation);
errors[branchName] = new Error(branchStatus.error);
return;
}

// Execute branch
const branchLocation = appendName(
this.storage,
location,
branchName,
);
const branchCtx = this.createBranch(branchLocation);

branchStatus.status = "running";
Expand All @@ -1516,6 +1545,7 @@ export class WorkflowContextImpl implements WorkflowContextInterface {
try {
const output = await config.run(branchCtx);
branchCtx.validateComplete();
this.mergeVisitedKeys(branchCtx);

branchStatus.status = "completed";
branchStatus.output = output;
Expand Down Expand Up @@ -1705,6 +1735,7 @@ export class WorkflowContextImpl implements WorkflowContextInterface {
winnerValue = output;

branchCtx.validateComplete();
this.mergeVisitedKeys(branchCtx);

branchStatus.status = "completed";
branchStatus.output = output;
Expand Down
Loading