From 5a553c03490b621adef69fd53c9527d9fb7c5113 Mon Sep 17 00:00:00 2001 From: Peter Wielander Date: Wed, 26 Nov 2025 15:23:56 -0800 Subject: [PATCH 1/6] [ai] Pass messages to step functions in DurableAgent --- packages/ai/src/agent/durable-agent.test.ts | 391 ++++++++++++++++-- packages/ai/src/agent/durable-agent.ts | 12 +- packages/ai/src/agent/stream-text-iterator.ts | 18 +- 3 files changed, 389 insertions(+), 32 deletions(-) diff --git a/packages/ai/src/agent/durable-agent.test.ts b/packages/ai/src/agent/durable-agent.test.ts index ce43b9591..234c1215a 100644 --- a/packages/ai/src/agent/durable-agent.test.ts +++ b/packages/ai/src/agent/durable-agent.test.ts @@ -2,10 +2,12 @@ * Tests for DurableAgent * * These tests focus on error handling in tool execution, - * particularly for FatalError conversion to tool result errors. + * particularly for FatalError conversion to tool result errors, + * and verifying that messages are properly passed to tool execute functions. */ import type { LanguageModelV2, + LanguageModelV2Prompt, LanguageModelV2ToolCall, } from '@ai-sdk/provider'; import { FatalError } from 'workflow'; @@ -58,20 +60,26 @@ describe('DurableAgent', () => { // Mock the streamTextIterator to return tool calls and then complete const { streamTextIterator } = await import('./stream-text-iterator.js'); + const mockMessages: LanguageModelV2Prompt = [ + { role: 'user', content: [{ type: 'text', text: 'test' }] }, + ]; const mockIterator = { next: vi .fn() .mockResolvedValueOnce({ done: false, - value: [ - { - toolCallId: 'test-call-id', - toolName: 'testTool', - input: '{}', - } as LanguageModelV2ToolCall, - ], + value: { + toolCalls: [ + { + toolCallId: 'test-call-id', + toolName: 'testTool', + input: '{}', + } as LanguageModelV2ToolCall, + ], + messages: mockMessages, + }, }) - .mockResolvedValueOnce({ done: true, value: undefined }), + .mockResolvedValueOnce({ done: true, value: [] }), }; vi.mocked(streamTextIterator).mockReturnValue( mockIterator as unknown 
as AsyncGenerator @@ -132,16 +140,22 @@ describe('DurableAgent', () => { }); const { streamTextIterator } = await import('./stream-text-iterator.js'); + const mockMessages: LanguageModelV2Prompt = [ + { role: 'user', content: [{ type: 'text', text: 'test' }] }, + ]; const mockIterator = { next: vi.fn().mockResolvedValueOnce({ done: false, - value: [ - { - toolCallId: 'test-call-id', - toolName: 'testTool', - input: '{}', - } as LanguageModelV2ToolCall, - ], + value: { + toolCalls: [ + { + toolCallId: 'test-call-id', + toolName: 'testTool', + input: '{}', + } as LanguageModelV2ToolCall, + ], + messages: mockMessages, + }, }), }; vi.mocked(streamTextIterator).mockReturnValue( @@ -186,20 +200,26 @@ describe('DurableAgent', () => { }); const { streamTextIterator } = await import('./stream-text-iterator.js'); + const mockMessages: LanguageModelV2Prompt = [ + { role: 'user', content: [{ type: 'text', text: 'test' }] }, + ]; const mockIterator = { next: vi .fn() .mockResolvedValueOnce({ done: false, - value: [ - { - toolCallId: 'test-call-id', - toolName: 'testTool', - input: '{}', - } as LanguageModelV2ToolCall, - ], + value: { + toolCalls: [ + { + toolCallId: 'test-call-id', + toolName: 'testTool', + input: '{}', + } as LanguageModelV2ToolCall, + ], + messages: mockMessages, + }, }) - .mockResolvedValueOnce({ done: true, value: undefined }), + .mockResolvedValueOnce({ done: true, value: [] }), }; vi.mocked(streamTextIterator).mockReturnValue( mockIterator as unknown as AsyncGenerator @@ -226,4 +246,327 @@ describe('DurableAgent', () => { }); }); }); + + describe('tool execution with messages', () => { + it('should pass conversation messages to tool execute function', async () => { + // Track what messages were passed to the tool + let receivedMessages: unknown; + let receivedToolCallId: string | undefined; + + const tools: ToolSet = { + testTool: { + description: 'A test tool', + inputSchema: z.object({ query: z.string() }), + execute: async (_input, options) => { + 
receivedMessages = options.messages; + receivedToolCallId = options.toolCallId; + return { result: 'success' }; + }, + }, + }; + + const mockModel: LanguageModelV2 = { + specificationVersion: 'v2' as const, + provider: 'test', + modelId: 'test-model', + doGenerate: vi.fn(), + doStream: vi.fn(), + }; + + const agent = new DurableAgent({ + model: async () => mockModel, + tools, + }); + + const mockWritable = new WritableStream({ + write: vi.fn(), + close: vi.fn(), + }); + + // Mock conversation messages that would be accumulated by the iterator + const conversationMessages: LanguageModelV2Prompt = [ + { + role: 'user', + content: [{ type: 'text', text: 'What is the weather?' }], + }, + { + role: 'assistant', + content: [ + { + type: 'tool-call', + toolCallId: 'test-call-id', + toolName: 'testTool', + input: { query: 'weather' }, + }, + ], + }, + ]; + + const { streamTextIterator } = await import('./stream-text-iterator.js'); + const mockIterator = { + next: vi + .fn() + .mockResolvedValueOnce({ + done: false, + value: { + toolCalls: [ + { + toolCallId: 'test-call-id', + toolName: 'testTool', + input: '{"query":"weather"}', + } as LanguageModelV2ToolCall, + ], + messages: conversationMessages, + }, + }) + .mockResolvedValueOnce({ done: true, value: [] }), + }; + vi.mocked(streamTextIterator).mockReturnValue( + mockIterator as unknown as AsyncGenerator + ); + + await agent.stream({ + messages: [{ role: 'user', content: 'What is the weather?' 
}], + writable: mockWritable, + }); + + // Verify that messages were passed to the tool + expect(receivedToolCallId).toBe('test-call-id'); + expect(receivedMessages).toBeDefined(); + expect(Array.isArray(receivedMessages)).toBe(true); + expect(receivedMessages).toEqual(conversationMessages); + }); + + it('should pass messages to multiple tools in parallel execution', async () => { + // Track messages received by each tool + const receivedByTools: Record = {}; + + const tools: ToolSet = { + weatherTool: { + description: 'Get weather', + inputSchema: z.object({ city: z.string() }), + execute: async (_input, options) => { + receivedByTools['weatherTool'] = options.messages; + return { temp: 72 }; + }, + }, + newsTool: { + description: 'Get news', + inputSchema: z.object({ topic: z.string() }), + execute: async (_input, options) => { + receivedByTools['newsTool'] = options.messages; + return { headlines: ['News 1'] }; + }, + }, + }; + + const mockModel: LanguageModelV2 = { + specificationVersion: 'v2' as const, + provider: 'test', + modelId: 'test-model', + doGenerate: vi.fn(), + doStream: vi.fn(), + }; + + const agent = new DurableAgent({ + model: async () => mockModel, + tools, + }); + + const mockWritable = new WritableStream({ + write: vi.fn(), + close: vi.fn(), + }); + + const conversationMessages: LanguageModelV2Prompt = [ + { + role: 'user', + content: [{ type: 'text', text: 'Weather and news please' }], + }, + { + role: 'assistant', + content: [ + { + type: 'tool-call', + toolCallId: 'weather-call', + toolName: 'weatherTool', + input: { city: 'NYC' }, + }, + { + type: 'tool-call', + toolCallId: 'news-call', + toolName: 'newsTool', + input: { topic: 'tech' }, + }, + ], + }, + ]; + + const { streamTextIterator } = await import('./stream-text-iterator.js'); + const mockIterator = { + next: vi + .fn() + .mockResolvedValueOnce({ + done: false, + value: { + toolCalls: [ + { + toolCallId: 'weather-call', + toolName: 'weatherTool', + input: '{"city":"NYC"}', + } as 
LanguageModelV2ToolCall, + { + toolCallId: 'news-call', + toolName: 'newsTool', + input: '{"topic":"tech"}', + } as LanguageModelV2ToolCall, + ], + messages: conversationMessages, + }, + }) + .mockResolvedValueOnce({ done: true, value: [] }), + }; + vi.mocked(streamTextIterator).mockReturnValue( + mockIterator as unknown as AsyncGenerator + ); + + await agent.stream({ + messages: [{ role: 'user', content: 'Weather and news please' }], + writable: mockWritable, + }); + + // Both tools should have received the same conversation messages + expect(receivedByTools['weatherTool']).toEqual(conversationMessages); + expect(receivedByTools['newsTool']).toEqual(conversationMessages); + }); + + it('should pass updated messages on subsequent tool call rounds', async () => { + // Track messages received in each round + const messagesPerRound: unknown[] = []; + + const tools: ToolSet = { + searchTool: { + description: 'Search for info', + inputSchema: z.object({ query: z.string() }), + execute: async (_input, options) => { + messagesPerRound.push(options.messages); + return { found: true }; + }, + }, + }; + + const mockModel: LanguageModelV2 = { + specificationVersion: 'v2' as const, + provider: 'test', + modelId: 'test-model', + doGenerate: vi.fn(), + doStream: vi.fn(), + }; + + const agent = new DurableAgent({ + model: async () => mockModel, + tools, + }); + + const mockWritable = new WritableStream({ + write: vi.fn(), + close: vi.fn(), + }); + + // First round messages + const firstRoundMessages: LanguageModelV2Prompt = [ + { role: 'user', content: [{ type: 'text', text: 'Search for cats' }] }, + { + role: 'assistant', + content: [ + { + type: 'tool-call', + toolCallId: 'search-1', + toolName: 'searchTool', + input: { query: 'cats' }, + }, + ], + }, + ]; + + // Second round messages (includes first tool result) + const secondRoundMessages: LanguageModelV2Prompt = [ + ...firstRoundMessages, + { + role: 'tool', + content: [ + { + type: 'tool-result', + toolCallId: 'search-1', + 
toolName: 'searchTool', + output: { type: 'text', value: '{"found":true}' }, + }, + ], + }, + { + role: 'assistant', + content: [ + { + type: 'tool-call', + toolCallId: 'search-2', + toolName: 'searchTool', + input: { query: 'dogs' }, + }, + ], + }, + ]; + + const { streamTextIterator } = await import('./stream-text-iterator.js'); + const mockIterator = { + next: vi + .fn() + // First tool call round + .mockResolvedValueOnce({ + done: false, + value: { + toolCalls: [ + { + toolCallId: 'search-1', + toolName: 'searchTool', + input: '{"query":"cats"}', + } as LanguageModelV2ToolCall, + ], + messages: firstRoundMessages, + }, + }) + // Second tool call round + .mockResolvedValueOnce({ + done: false, + value: { + toolCalls: [ + { + toolCallId: 'search-2', + toolName: 'searchTool', + input: '{"query":"dogs"}', + } as LanguageModelV2ToolCall, + ], + messages: secondRoundMessages, + }, + }) + .mockResolvedValueOnce({ done: true, value: [] }), + }; + vi.mocked(streamTextIterator).mockReturnValue( + mockIterator as unknown as AsyncGenerator + ); + + await agent.stream({ + messages: [{ role: 'user', content: 'Search for cats' }], + writable: mockWritable, + }); + + // Verify messages grow with each round + expect(messagesPerRound).toHaveLength(2); + expect(messagesPerRound[0]).toEqual(firstRoundMessages); + expect(messagesPerRound[1]).toEqual(secondRoundMessages); + // Second round should have more messages than first + expect((messagesPerRound[1] as unknown[]).length).toBeGreaterThan( + (messagesPerRound[0] as unknown[]).length + ); + }); + }); }); diff --git a/packages/ai/src/agent/durable-agent.ts b/packages/ai/src/agent/durable-agent.ts index cfe3c50e2..9ac809f38 100644 --- a/packages/ai/src/agent/durable-agent.ts +++ b/packages/ai/src/agent/durable-agent.ts @@ -1,5 +1,6 @@ import type { LanguageModelV2, + LanguageModelV2Prompt, LanguageModelV2ToolCall, LanguageModelV2ToolResultPart, } from '@ai-sdk/provider'; @@ -158,11 +159,11 @@ export class DurableAgent { let result 
= await iterator.next(); while (!result.done) { - const toolCalls = result.value; + const { toolCalls, messages } = result.value; const toolResults = await Promise.all( toolCalls.map( (toolCall): Promise => - executeTool(toolCall, this.tools) + executeTool(toolCall, this.tools, messages) ) ); result = await iterator.next(toolResults); @@ -209,7 +210,8 @@ async function closeStream( async function executeTool( toolCall: LanguageModelV2ToolCall, - tools: ToolSet + tools: ToolSet, + messages: LanguageModelV2Prompt ): Promise { const tool = tools[toolCall.toolName]; if (!tool) throw new Error(`Tool "${toolCall.toolName}" not found`); @@ -228,8 +230,8 @@ async function executeTool( try { const toolResult = await tool.execute(input.value, { toolCallId: toolCall.toolCallId, - // TODO: pass the proper messages to the tool (we'd need to pass them through the iterator) - messages: [], + // Pass the conversation messages to the tool so it has context about the conversation + messages, }); return { diff --git a/packages/ai/src/agent/stream-text-iterator.ts b/packages/ai/src/agent/stream-text-iterator.ts index db137ec88..267e56afd 100644 --- a/packages/ai/src/agent/stream-text-iterator.ts +++ b/packages/ai/src/agent/stream-text-iterator.ts @@ -13,6 +13,17 @@ import type { import { doStreamStep, type ModelStopCondition } from './do-stream-step.js'; import { toolsToModelTools } from './tools-to-model-tools.js'; +/** + * The value yielded by the stream text iterator when tool calls are requested. + * Contains both the tool calls and the current conversation messages. 
+ */ +export interface StreamTextIteratorYieldValue { + /** The tool calls requested by the model */ + toolCalls: LanguageModelV2ToolCall[]; + /** The conversation messages up to (and including) the tool call request */ + messages: LanguageModelV2Prompt; +} + // This runs in the workflow context export async function* streamTextIterator({ prompt, @@ -31,7 +42,7 @@ export async function* streamTextIterator({ sendStart?: boolean; onStepFinish?: StreamTextOnStepFinishCallback; }): AsyncGenerator< - LanguageModelV2ToolCall[], + StreamTextIteratorYieldValue, LanguageModelV2Prompt, LanguageModelV2ToolResultPart[] > { @@ -66,8 +77,9 @@ export async function* streamTextIterator({ })), }); - // Yield the tool calls and wait for results - const toolResults = yield toolCalls; + // Yield the tool calls along with the current conversation messages + // This allows executeTool to pass the conversation context to tool execute functions + const toolResults = yield { toolCalls, messages: conversationPrompt }; await writeToolOutputToUI(writable, toolResults); From a246204cf1f9059f0ad06485446d7895caf699ba Mon Sep 17 00:00:00 2001 From: Peter Wielander Date: Wed, 26 Nov 2025 15:29:52 -0800 Subject: [PATCH 2/6] Changeset --- .changeset/funny-games-sniff.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changeset/funny-games-sniff.md diff --git a/.changeset/funny-games-sniff.md b/.changeset/funny-games-sniff.md new file mode 100644 index 000000000..8e80eef46 --- /dev/null +++ b/.changeset/funny-games-sniff.md @@ -0,0 +1,5 @@ +--- +"@workflow/ai": patch +--- + +Make current messages state available to tool calls From 3a8ad8539f0041dc705be69d2f1c95defe400669 Mon Sep 17 00:00:00 2001 From: Peter Wielander Date: Wed, 26 Nov 2025 18:11:11 -0800 Subject: [PATCH 3/6] Add prepareStep --- packages/ai/src/agent/durable-agent.ts | 79 ++++++++++++++++++- packages/ai/src/agent/stream-text-iterator.ts | 28 ++++++- 2 files changed, 103 insertions(+), 4 deletions(-) diff --git 
a/packages/ai/src/agent/durable-agent.ts b/packages/ai/src/agent/durable-agent.ts index 9ac809f38..2f05f363a 100644 --- a/packages/ai/src/agent/durable-agent.ts +++ b/packages/ai/src/agent/durable-agent.ts @@ -7,6 +7,7 @@ import type { import { asSchema, type ModelMessage, + type StepResult, type StopCondition, type StreamTextOnStepFinishCallback, type ToolSet, @@ -16,6 +17,57 @@ import { convertToLanguageModelPrompt, standardizePrompt } from 'ai/internal'; import { FatalError } from 'workflow'; import { streamTextIterator } from './stream-text-iterator.js'; +/** + * Information passed to the prepareStep callback. + */ +export interface PrepareStepInfo { + /** + * The current model configuration (string or function). + */ + model: string | (() => Promise); + + /** + * The current step number (0-indexed). + */ + stepNumber: number; + + /** + * All previous steps with their results. + */ + steps: StepResult[]; + + /** + * The messages that will be sent to the model. + * This is the LanguageModelV2Prompt format used internally. + */ + messages: LanguageModelV2Prompt; +} + +/** + * Return type from the prepareStep callback. + * All properties are optional - only return the ones you want to override. + */ +export interface PrepareStepResult { + /** + * Override the model for this step. + */ + model?: string | (() => Promise); + + /** + * Override the messages for this step. + * Use this for context management or message injection. + */ + messages?: LanguageModelV2Prompt; +} + +/** + * Callback function called before each step in the agent loop. + * Use this to modify settings, manage context, or implement dynamic behavior. + */ +export type PrepareStepCallback = ( + info: PrepareStepInfo +) => PrepareStepResult | Promise; + /** * Configuration options for creating a {@link DurableAgent} instance. */ @@ -44,7 +96,7 @@ export interface DurableAgentOptions { /** * Options for the {@link DurableAgent.stream} method. 
*/ -export interface DurableAgentStreamOptions { +export interface DurableAgentStreamOptions { /** * The conversation messages to process. Should follow the AI SDK's ModelMessage format. */ @@ -90,6 +142,26 @@ export interface DurableAgentStreamOptions { * Callback function to be called after each step completes. */ onStepFinish?: StreamTextOnStepFinishCallback; + + /** + * Callback function called before each step in the agent loop. + * Use this to modify settings, manage context, or inject messages dynamically. + * + * @example + * ```typescript + * prepareStep: async ({ messages, stepNumber }) => { + * // Inject messages from a queue + * const queuedMessages = await getQueuedMessages(); + * if (queuedMessages.length > 0) { + * return { + * messages: [...messages, ...queuedMessages], + * }; + * } + * return {}; + * } + * ``` + */ + prepareStep?: PrepareStepCallback; } /** @@ -135,7 +207,9 @@ export class DurableAgent { throw new Error('Not implemented'); } - async stream(options: DurableAgentStreamOptions) { + async stream( + options: DurableAgentStreamOptions + ) { const prompt = await standardizePrompt({ system: options.system || this.system, messages: options.messages, @@ -155,6 +229,7 @@ export class DurableAgent { stopConditions: options.stopWhen, sendStart: options.sendStart ?? 
true, onStepFinish: options.onStepFinish, + prepareStep: options.prepareStep, }); let result = await iterator.next(); diff --git a/packages/ai/src/agent/stream-text-iterator.ts b/packages/ai/src/agent/stream-text-iterator.ts index 267e56afd..cbe808760 100644 --- a/packages/ai/src/agent/stream-text-iterator.ts +++ b/packages/ai/src/agent/stream-text-iterator.ts @@ -11,6 +11,7 @@ import type { UIMessageChunk, } from 'ai'; import { doStreamStep, type ModelStopCondition } from './do-stream-step.js'; +import type { PrepareStepCallback } from './durable-agent.js'; import { toolsToModelTools } from './tools-to-model-tools.js'; /** @@ -33,6 +34,7 @@ export async function* streamTextIterator({ stopConditions, sendStart = true, onStepFinish, + prepareStep, }: { prompt: LanguageModelV2Prompt; tools: ToolSet; @@ -41,21 +43,42 @@ export async function* streamTextIterator({ stopConditions?: ModelStopCondition[] | ModelStopCondition; sendStart?: boolean; onStepFinish?: StreamTextOnStepFinishCallback; + prepareStep?: PrepareStepCallback; }): AsyncGenerator< StreamTextIteratorYieldValue, LanguageModelV2Prompt, LanguageModelV2ToolResultPart[] > { - const conversationPrompt = [...prompt]; // Create a mutable copy + let conversationPrompt = [...prompt]; // Create a mutable copy + let currentModel = model; const steps: StepResult[] = []; let done = false; let isFirstIteration = true; + let stepNumber = 0; while (!done) { + // Call prepareStep callback before each step if provided + if (prepareStep) { + const prepareResult = await prepareStep({ + model: currentModel, + stepNumber, + steps, + messages: conversationPrompt, + }); + + // Apply any overrides from prepareStep + if (prepareResult.model !== undefined) { + currentModel = prepareResult.model; + } + if (prepareResult.messages !== undefined) { + conversationPrompt = [...prepareResult.messages]; + } + } + const { toolCalls, finish, step } = await doStreamStep( conversationPrompt, - model, + currentModel, writable, 
toolsToModelTools(tools), { @@ -63,6 +86,7 @@ export async function* streamTextIterator({ } ); isFirstIteration = false; + stepNumber++; steps.push(step); if (finish?.finishReason === 'tool-calls') { From 00f21edf20c78db6481f0e53b2d9de47d4b49734 Mon Sep 17 00:00:00 2001 From: Peter Wielander Date: Wed, 26 Nov 2025 18:11:45 -0800 Subject: [PATCH 4/6] Changeset --- .changeset/dirty-sloths-cut.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changeset/dirty-sloths-cut.md diff --git a/.changeset/dirty-sloths-cut.md b/.changeset/dirty-sloths-cut.md new file mode 100644 index 000000000..54aa2b79f --- /dev/null +++ b/.changeset/dirty-sloths-cut.md @@ -0,0 +1,5 @@ +--- +"@workflow/ai": patch +--- + +Add `prepareStep` argument for DurableAgent to modify messages between AI loop steps From b13653516528b45756c30a4b5ef48f5f961903da Mon Sep 17 00:00:00 2001 From: Peter Wielander Date: Wed, 26 Nov 2025 18:14:00 -0800 Subject: [PATCH 5/6] tests --- packages/ai/src/agent/durable-agent.test.ts | 212 +++++++++++++++++++- 1 file changed, 210 insertions(+), 2 deletions(-) diff --git a/packages/ai/src/agent/durable-agent.test.ts b/packages/ai/src/agent/durable-agent.test.ts index 234c1215a..ce18499a8 100644 --- a/packages/ai/src/agent/durable-agent.test.ts +++ b/packages/ai/src/agent/durable-agent.test.ts @@ -10,9 +10,9 @@ import type { LanguageModelV2Prompt, LanguageModelV2ToolCall, } from '@ai-sdk/provider'; -import { FatalError } from 'workflow'; -import { describe, expect, it, vi } from 'vitest'; import type { ToolSet } from 'ai'; +import { describe, expect, it, vi } from 'vitest'; +import { FatalError } from 'workflow'; import { z } from 'zod'; // Mock the streamTextIterator @@ -23,6 +23,8 @@ vi.mock('./stream-text-iterator.js', () => ({ // Import after mocking const { DurableAgent } = await import('./durable-agent.js'); +import type { PrepareStepCallback } from './durable-agent.js'; + describe('DurableAgent', () => { describe('tool execution error handling', () => 
{ it('should convert FatalError to tool error result', async () => { @@ -247,6 +249,212 @@ describe('DurableAgent', () => { }); }); + describe('prepareStep callback', () => { + it('should pass prepareStep callback to streamTextIterator', async () => { + const mockModel: LanguageModelV2 = { + specificationVersion: 'v2' as const, + provider: 'test', + modelId: 'test-model', + doGenerate: vi.fn(), + doStream: vi.fn(), + }; + + const agent = new DurableAgent({ + model: async () => mockModel, + tools: {}, + }); + + const mockWritable = new WritableStream({ + write: vi.fn(), + close: vi.fn(), + }); + + const { streamTextIterator } = await import('./stream-text-iterator.js'); + const mockIterator = { + next: vi.fn().mockResolvedValueOnce({ done: true, value: [] }), + }; + vi.mocked(streamTextIterator).mockReturnValue( + mockIterator as unknown as AsyncGenerator + ); + + const prepareStep: PrepareStepCallback = vi.fn().mockReturnValue({}); + + await agent.stream({ + messages: [{ role: 'user', content: 'test' }], + writable: mockWritable, + prepareStep, + }); + + // Verify streamTextIterator was called with prepareStep + expect(streamTextIterator).toHaveBeenCalledWith( + expect.objectContaining({ + prepareStep, + }) + ); + }); + + it('should allow prepareStep to modify messages', async () => { + const mockModel: LanguageModelV2 = { + specificationVersion: 'v2' as const, + provider: 'test', + modelId: 'test-model', + doGenerate: vi.fn(), + doStream: vi.fn(), + }; + + const agent = new DurableAgent({ + model: async () => mockModel, + tools: {}, + }); + + const mockWritable = new WritableStream({ + write: vi.fn(), + close: vi.fn(), + }); + + const { streamTextIterator } = await import('./stream-text-iterator.js'); + const mockIterator = { + next: vi.fn().mockResolvedValueOnce({ done: true, value: [] }), + }; + vi.mocked(streamTextIterator).mockReturnValue( + mockIterator as unknown as AsyncGenerator + ); + + const injectedMessage = { + role: 'user' as const, + content: [{ 
type: 'text' as const, text: 'injected message' }], + }; + + const prepareStep: PrepareStepCallback = ({ messages }) => { + return { + messages: [...messages, injectedMessage], + }; + }; + + await agent.stream({ + messages: [{ role: 'user', content: 'test' }], + writable: mockWritable, + prepareStep, + }); + + // Verify prepareStep was passed to the iterator + expect(streamTextIterator).toHaveBeenCalledWith( + expect.objectContaining({ + prepareStep: expect.any(Function), + }) + ); + }); + + it('should allow prepareStep to change model dynamically', async () => { + const mockModel: LanguageModelV2 = { + specificationVersion: 'v2' as const, + provider: 'test', + modelId: 'test-model', + doGenerate: vi.fn(), + doStream: vi.fn(), + }; + + const agent = new DurableAgent({ + model: async () => mockModel, + tools: {}, + }); + + const mockWritable = new WritableStream({ + write: vi.fn(), + close: vi.fn(), + }); + + const { streamTextIterator } = await import('./stream-text-iterator.js'); + const mockIterator = { + next: vi.fn().mockResolvedValueOnce({ done: true, value: [] }), + }; + vi.mocked(streamTextIterator).mockReturnValue( + mockIterator as unknown as AsyncGenerator + ); + + const prepareStep: PrepareStepCallback = ({ stepNumber }) => { + // Switch to a different model after step 0 + if (stepNumber > 0) { + return { + model: 'anthropic/claude-sonnet-4.5', + }; + } + return {}; + }; + + await agent.stream({ + messages: [{ role: 'user', content: 'test' }], + writable: mockWritable, + prepareStep, + }); + + // Verify prepareStep was passed to the iterator + expect(streamTextIterator).toHaveBeenCalledWith( + expect.objectContaining({ + prepareStep: expect.any(Function), + }) + ); + }); + + it('should provide step information to prepareStep callback', async () => { + const mockModel: LanguageModelV2 = { + specificationVersion: 'v2' as const, + provider: 'test', + modelId: 'test-model', + doGenerate: vi.fn(), + doStream: vi.fn(), + }; + + const agent = new DurableAgent({ 
+ model: async () => mockModel, + tools: {}, + }); + + const mockWritable = new WritableStream({ + write: vi.fn(), + close: vi.fn(), + }); + + const { streamTextIterator } = await import('./stream-text-iterator.js'); + const mockIterator = { + next: vi.fn().mockResolvedValueOnce({ done: true, value: [] }), + }; + vi.mocked(streamTextIterator).mockReturnValue( + mockIterator as unknown as AsyncGenerator + ); + + const prepareStepCalls: Array<{ + model: unknown; + stepNumber: number; + steps: unknown[]; + messages: LanguageModelV2Prompt; + }> = []; + + const prepareStep: PrepareStepCallback = (info) => { + prepareStepCalls.push({ + model: info.model, + stepNumber: info.stepNumber, + steps: info.steps, + messages: info.messages, + }); + return {}; + }; + + await agent.stream({ + messages: [{ role: 'user', content: 'test' }], + writable: mockWritable, + prepareStep, + }); + + // Verify prepareStep was passed and the function captures expected params + expect(streamTextIterator).toHaveBeenCalledWith( + expect.objectContaining({ + prepareStep: expect.any(Function), + }) + ); + }); + }); + describe('tool execution with messages', () => { it('should pass conversation messages to tool execute function', async () => { // Track what messages were passed to the tool From ae206104e36eac9438eae07b27c28c0b531f351d Mon Sep 17 00:00:00 2001 From: Peter Wielander Date: Wed, 26 Nov 2025 18:22:00 -0800 Subject: [PATCH 6/6] Better tests --- packages/ai/src/agent/durable-agent.test.ts | 125 +++++++------------- 1 file changed, 45 insertions(+), 80 deletions(-) diff --git a/packages/ai/src/agent/durable-agent.test.ts b/packages/ai/src/agent/durable-agent.test.ts index ce18499a8..e47202e4b 100644 --- a/packages/ai/src/agent/durable-agent.test.ts +++ b/packages/ai/src/agent/durable-agent.test.ts @@ -9,6 +9,7 @@ import type { LanguageModelV2, LanguageModelV2Prompt, LanguageModelV2ToolCall, + LanguageModelV2ToolResultPart, } from '@ai-sdk/provider'; import type { ToolSet } from 'ai'; import 
{ describe, expect, it, vi } from 'vitest'; @@ -24,6 +25,30 @@ vi.mock('./stream-text-iterator.js', () => ({ const { DurableAgent } = await import('./durable-agent.js'); import type { PrepareStepCallback } from './durable-agent.js'; +import type { StreamTextIteratorYieldValue } from './stream-text-iterator.js'; + +/** + * Creates a mock LanguageModelV2 for testing + */ +function createMockModel(): LanguageModelV2 { + return { + specificationVersion: 'v2' as const, + provider: 'test', + modelId: 'test-model', + doGenerate: vi.fn(), + doStream: vi.fn(), + supportedUrls: {}, + }; +} + +/** + * Type for the mock iterator used in tests + */ +type MockIterator = AsyncGenerator< + StreamTextIteratorYieldValue, + LanguageModelV2Prompt, + LanguageModelV2ToolResultPart[] +>; describe('DurableAgent', () => { describe('tool execution error handling', () => { @@ -41,13 +66,7 @@ describe('DurableAgent', () => { // We need to test the executeTool function indirectly through the agent // Create a mock model that will trigger tool calls - const mockModel: LanguageModelV2 = { - specificationVersion: 'v2' as const, - provider: 'test', - modelId: 'test-model', - doGenerate: vi.fn(), - doStream: vi.fn(), - }; + const mockModel = createMockModel(); const agent = new DurableAgent({ model: async () => mockModel, @@ -84,7 +103,7 @@ describe('DurableAgent', () => { .mockResolvedValueOnce({ done: true, value: [] }), }; vi.mocked(streamTextIterator).mockReturnValue( - mockIterator as unknown as AsyncGenerator + mockIterator as unknown as MockIterator ); // Execute the stream - this should not throw even though the tool throws FatalError @@ -123,13 +142,7 @@ describe('DurableAgent', () => { }, }; - const mockModel: LanguageModelV2 = { - specificationVersion: 'v2' as const, - provider: 'test', - modelId: 'test-model', - doGenerate: vi.fn(), - doStream: vi.fn(), - }; + const mockModel = createMockModel(); const agent = new DurableAgent({ model: async () => mockModel, @@ -161,7 +174,7 @@ 
describe('DurableAgent', () => { }), }; vi.mocked(streamTextIterator).mockReturnValue( - mockIterator as unknown as AsyncGenerator + mockIterator as unknown as MockIterator ); // Execute should throw because non-FatalErrors are re-thrown @@ -183,13 +196,7 @@ describe('DurableAgent', () => { }, }; - const mockModel: LanguageModelV2 = { - specificationVersion: 'v2' as const, - provider: 'test', - modelId: 'test-model', - doGenerate: vi.fn(), - doStream: vi.fn(), - }; + const mockModel = createMockModel(); const agent = new DurableAgent({ model: async () => mockModel, @@ -224,7 +231,7 @@ describe('DurableAgent', () => { .mockResolvedValueOnce({ done: true, value: [] }), }; vi.mocked(streamTextIterator).mockReturnValue( - mockIterator as unknown as AsyncGenerator + mockIterator as unknown as MockIterator ); await agent.stream({ @@ -251,13 +258,7 @@ describe('DurableAgent', () => { describe('prepareStep callback', () => { it('should pass prepareStep callback to streamTextIterator', async () => { - const mockModel: LanguageModelV2 = { - specificationVersion: 'v2' as const, - provider: 'test', - modelId: 'test-model', - doGenerate: vi.fn(), - doStream: vi.fn(), - }; + const mockModel = createMockModel(); const agent = new DurableAgent({ model: async () => mockModel, @@ -274,7 +275,7 @@ describe('DurableAgent', () => { next: vi.fn().mockResolvedValueOnce({ done: true, value: [] }), }; vi.mocked(streamTextIterator).mockReturnValue( - mockIterator as unknown as AsyncGenerator + mockIterator as unknown as MockIterator ); const prepareStep: PrepareStepCallback = vi.fn().mockReturnValue({}); @@ -294,13 +295,7 @@ describe('DurableAgent', () => { }); it('should allow prepareStep to modify messages', async () => { - const mockModel: LanguageModelV2 = { - specificationVersion: 'v2' as const, - provider: 'test', - modelId: 'test-model', - doGenerate: vi.fn(), - doStream: vi.fn(), - }; + const mockModel = createMockModel(); const agent = new DurableAgent({ model: async () => 
mockModel, @@ -317,7 +312,7 @@ describe('DurableAgent', () => { next: vi.fn().mockResolvedValueOnce({ done: true, value: [] }), }; vi.mocked(streamTextIterator).mockReturnValue( - mockIterator as unknown as AsyncGenerator + mockIterator as unknown as MockIterator ); const injectedMessage = { @@ -346,13 +341,7 @@ describe('DurableAgent', () => { }); it('should allow prepareStep to change model dynamically', async () => { - const mockModel: LanguageModelV2 = { - specificationVersion: 'v2' as const, - provider: 'test', - modelId: 'test-model', - doGenerate: vi.fn(), - doStream: vi.fn(), - }; + const mockModel = createMockModel(); const agent = new DurableAgent({ model: async () => mockModel, @@ -369,7 +358,7 @@ describe('DurableAgent', () => { next: vi.fn().mockResolvedValueOnce({ done: true, value: [] }), }; vi.mocked(streamTextIterator).mockReturnValue( - mockIterator as unknown as AsyncGenerator + mockIterator as unknown as MockIterator ); const prepareStep: PrepareStepCallback = ({ stepNumber }) => { @@ -397,13 +386,7 @@ describe('DurableAgent', () => { }); it('should provide step information to prepareStep callback', async () => { - const mockModel: LanguageModelV2 = { - specificationVersion: 'v2' as const, - provider: 'test', - modelId: 'test-model', - doGenerate: vi.fn(), - doStream: vi.fn(), - }; + const mockModel = createMockModel(); const agent = new DurableAgent({ model: async () => mockModel, @@ -420,7 +403,7 @@ describe('DurableAgent', () => { next: vi.fn().mockResolvedValueOnce({ done: true, value: [] }), }; vi.mocked(streamTextIterator).mockReturnValue( - mockIterator as unknown as AsyncGenerator + mockIterator as unknown as MockIterator ); const prepareStepCalls: Array<{ @@ -473,13 +456,7 @@ describe('DurableAgent', () => { }, }; - const mockModel: LanguageModelV2 = { - specificationVersion: 'v2' as const, - provider: 'test', - modelId: 'test-model', - doGenerate: vi.fn(), - doStream: vi.fn(), - }; + const mockModel = createMockModel(); const agent = 
new DurableAgent({ model: async () => mockModel, @@ -530,7 +507,7 @@ describe('DurableAgent', () => { .mockResolvedValueOnce({ done: true, value: [] }), }; vi.mocked(streamTextIterator).mockReturnValue( - mockIterator as unknown as AsyncGenerator + mockIterator as unknown as MockIterator ); await agent.stream({ @@ -568,13 +545,7 @@ describe('DurableAgent', () => { }, }; - const mockModel: LanguageModelV2 = { - specificationVersion: 'v2' as const, - provider: 'test', - modelId: 'test-model', - doGenerate: vi.fn(), - doStream: vi.fn(), - }; + const mockModel = createMockModel(); const agent = new DurableAgent({ model: async () => mockModel, @@ -635,7 +606,7 @@ describe('DurableAgent', () => { .mockResolvedValueOnce({ done: true, value: [] }), }; vi.mocked(streamTextIterator).mockReturnValue( - mockIterator as unknown as AsyncGenerator + mockIterator as unknown as MockIterator ); await agent.stream({ @@ -663,13 +634,7 @@ describe('DurableAgent', () => { }, }; - const mockModel: LanguageModelV2 = { - specificationVersion: 'v2' as const, - provider: 'test', - modelId: 'test-model', - doGenerate: vi.fn(), - doStream: vi.fn(), - }; + const mockModel = createMockModel(); const agent = new DurableAgent({ model: async () => mockModel, @@ -759,7 +724,7 @@ describe('DurableAgent', () => { .mockResolvedValueOnce({ done: true, value: [] }), }; vi.mocked(streamTextIterator).mockReturnValue( - mockIterator as unknown as AsyncGenerator + mockIterator as unknown as MockIterator ); await agent.stream({