diff --git a/.changeset/dirty-sloths-cut.md b/.changeset/dirty-sloths-cut.md new file mode 100644 index 000000000..54aa2b79f --- /dev/null +++ b/.changeset/dirty-sloths-cut.md @@ -0,0 +1,5 @@ +--- +"@workflow/ai": patch +--- + +Add `prepareStep` argument for DurableAgent to modify messages between AI loop steps diff --git a/.changeset/funny-games-sniff.md b/.changeset/funny-games-sniff.md new file mode 100644 index 000000000..8e80eef46 --- /dev/null +++ b/.changeset/funny-games-sniff.md @@ -0,0 +1,5 @@ +--- +"@workflow/ai": patch +--- + +Make current messages state available to tool calls diff --git a/packages/ai/src/agent/durable-agent.test.ts b/packages/ai/src/agent/durable-agent.test.ts index ce43b9591..e47202e4b 100644 --- a/packages/ai/src/agent/durable-agent.test.ts +++ b/packages/ai/src/agent/durable-agent.test.ts @@ -2,15 +2,18 @@ * Tests for DurableAgent * * These tests focus on error handling in tool execution, - * particularly for FatalError conversion to tool result errors. + * particularly for FatalError conversion to tool result errors, + * and verifying that messages are properly passed to tool execute functions. */ import type { LanguageModelV2, + LanguageModelV2Prompt, LanguageModelV2ToolCall, + LanguageModelV2ToolResultPart, } from '@ai-sdk/provider'; -import { FatalError } from 'workflow'; -import { describe, expect, it, vi } from 'vitest'; import type { ToolSet } from 'ai'; +import { describe, expect, it, vi } from 'vitest'; +import { FatalError } from 'workflow'; import { z } from 'zod'; // Mock the streamTextIterator @@ -21,6 +24,32 @@ vi.mock('./stream-text-iterator.js', () => ({ // Import after mocking const { DurableAgent } = await import('./durable-agent.js'); +import type { PrepareStepCallback } from './durable-agent.js'; +import type { StreamTextIteratorYieldValue } from './stream-text-iterator.js'; + +/** + * Creates a mock LanguageModelV2 for testing + */ +function createMockModel(): LanguageModelV2 { + return { + specificationVersion: 'v2' as const, + provider: 'test', + modelId: 'test-model', + doGenerate: vi.fn(), + doStream: vi.fn(), + supportedUrls: {}, + }; +} + +/** + * Type for the mock iterator used in tests + */ +type MockIterator = AsyncGenerator< + StreamTextIteratorYieldValue, + LanguageModelV2Prompt, + LanguageModelV2ToolResultPart[] +>; + describe('DurableAgent', () => { describe('tool execution error handling', () => { it('should convert FatalError to tool error result', async () => { @@ -37,13 +66,7 @@ describe('DurableAgent', () => { // We need to test the executeTool function indirectly through the agent // Create a mock model that will trigger tool calls - const mockModel: LanguageModelV2 = { - specificationVersion: 'v2' as const, - provider: 'test', - modelId: 'test-model', - doGenerate: vi.fn(), - doStream: vi.fn(), - }; + const mockModel = createMockModel(); const agent = new DurableAgent({ model: async () => mockModel, @@ -58,23 +81,29 @@ describe('DurableAgent', () => { // Mock the streamTextIterator to return tool calls and then complete const { streamTextIterator } = await import('./stream-text-iterator.js'); + const mockMessages: LanguageModelV2Prompt = [ + { role: 'user', content: [{ type: 'text', text: 'test' }] }, + ]; const mockIterator = { next: vi .fn() .mockResolvedValueOnce({ done: false, - value: [ - { - toolCallId: 'test-call-id', - toolName: 'testTool', - input: '{}', - } as LanguageModelV2ToolCall, - ], + value: { + toolCalls: [ + { + toolCallId: 'test-call-id', + toolName: 'testTool', + input: '{}', + } as LanguageModelV2ToolCall, + ], + messages: mockMessages, + }, }) - .mockResolvedValueOnce({ done: true, value: undefined }), + .mockResolvedValueOnce({ done: true, value: [] }), }; vi.mocked(streamTextIterator).mockReturnValue( - mockIterator as unknown as AsyncGenerator + mockIterator as unknown as MockIterator ); // Execute the stream - this should not throw even though the tool throws FatalError @@ -113,13 +142,7 @@ describe('DurableAgent', () => { }, }; - const mockModel: LanguageModelV2 = { - specificationVersion: 'v2' as const, - provider: 'test', - modelId: 'test-model', - doGenerate: vi.fn(), - doStream: vi.fn(), - }; + const mockModel = createMockModel(); const agent = new DurableAgent({ model: async () => mockModel, @@ -132,20 +155,26 @@ describe('DurableAgent', () => { }); const { streamTextIterator } = await import('./stream-text-iterator.js'); + const mockMessages: LanguageModelV2Prompt = [ + { role: 'user', content: [{ type: 'text', text: 'test' }] }, + ]; const mockIterator = { next: vi.fn().mockResolvedValueOnce({ done: false, - value: [ - { - toolCallId: 'test-call-id', - toolName: 'testTool', - input: '{}', - } as LanguageModelV2ToolCall, - ], + value: { + toolCalls: [ + { + toolCallId: 'test-call-id', + toolName: 'testTool', + input: '{}', + } as LanguageModelV2ToolCall, + ], + messages: mockMessages, + }, }), }; vi.mocked(streamTextIterator).mockReturnValue( - mockIterator as unknown as AsyncGenerator + mockIterator as unknown as MockIterator ); // Execute should throw because non-FatalErrors are re-thrown @@ -167,13 +196,7 @@ describe('DurableAgent', () => { }, }; - const mockModel: LanguageModelV2 = { - specificationVersion: 'v2' as const, - provider: 'test', - modelId: 'test-model', - doGenerate: vi.fn(), - doStream: vi.fn(), - }; + const mockModel = createMockModel(); const agent = new DurableAgent({ model: async () => mockModel, @@ -186,23 +209,29 @@ describe('DurableAgent', () => { }); const { streamTextIterator } = await import('./stream-text-iterator.js'); + const mockMessages: LanguageModelV2Prompt = [ + { role: 'user', content: [{ type: 'text', text: 'test' }] }, + ]; const mockIterator = { next: vi .fn() .mockResolvedValueOnce({ done: false, - value: [ - { - toolCallId: 'test-call-id', - toolName: 'testTool', - input: '{}', - } as LanguageModelV2ToolCall, - ], + value: { + toolCalls: [ + { + toolCallId: 'test-call-id', + toolName: 'testTool', + input: '{}', + } as LanguageModelV2ToolCall, + ], + messages: mockMessages, + }, }) - .mockResolvedValueOnce({ done: true, value: undefined }), + .mockResolvedValueOnce({ done: true, value: [] }), }; vi.mocked(streamTextIterator).mockReturnValue( - mockIterator as unknown as AsyncGenerator + mockIterator as unknown as MockIterator ); await agent.stream({ @@ -226,4 +255,491 @@ describe('DurableAgent', () => { }); }); }); + + describe('prepareStep callback', () => { + it('should pass prepareStep callback to streamTextIterator', async () => { + const mockModel = createMockModel(); + + const agent = new DurableAgent({ + model: async () => mockModel, + tools: {}, + }); + + const mockWritable = new WritableStream({ + write: vi.fn(), + close: vi.fn(), + }); + + const { streamTextIterator } = await import('./stream-text-iterator.js'); + const mockIterator = { + next: vi.fn().mockResolvedValueOnce({ done: true, value: [] }), + }; + vi.mocked(streamTextIterator).mockReturnValue( + mockIterator as unknown as MockIterator + ); + + const prepareStep: PrepareStepCallback = vi.fn().mockReturnValue({}); + + await agent.stream({ + messages: [{ role: 'user', content: 'test' }], + writable: mockWritable, + prepareStep, + }); + + // Verify streamTextIterator was called with prepareStep + expect(streamTextIterator).toHaveBeenCalledWith( + expect.objectContaining({ + prepareStep, + }) + ); + }); + + it('should allow prepareStep to modify messages', async () => { + const mockModel = createMockModel(); + + const agent = new DurableAgent({ + model: async () => mockModel, + tools: {}, + }); + + const mockWritable = new WritableStream({ + write: vi.fn(), + close: vi.fn(), + }); + + const { streamTextIterator } = await import('./stream-text-iterator.js'); + const mockIterator = { + next: vi.fn().mockResolvedValueOnce({ done: true, value: [] }), + }; + vi.mocked(streamTextIterator).mockReturnValue( + mockIterator as unknown as MockIterator + ); + + const injectedMessage = { + role: 'user' as const, + content: [{ type: 'text' as const, text: 'injected message' }], + }; + + const prepareStep: PrepareStepCallback = ({ messages }) => { + return { + messages: [...messages, injectedMessage], + }; + }; + + await agent.stream({ + messages: [{ role: 'user', content: 'test' }], + writable: mockWritable, + prepareStep, + }); + + // Verify prepareStep was passed to the iterator + expect(streamTextIterator).toHaveBeenCalledWith( + expect.objectContaining({ + prepareStep: expect.any(Function), + }) + ); + }); + + it('should allow prepareStep to change model dynamically', async () => { + const mockModel = createMockModel(); + + const agent = new DurableAgent({ + model: async () => mockModel, + tools: {}, + }); + + const mockWritable = new WritableStream({ + write: vi.fn(), + close: vi.fn(), + }); + + const { streamTextIterator } = await import('./stream-text-iterator.js'); + const mockIterator = { + next: vi.fn().mockResolvedValueOnce({ done: true, value: [] }), + }; + vi.mocked(streamTextIterator).mockReturnValue( + mockIterator as unknown as MockIterator + ); + + const prepareStep: PrepareStepCallback = ({ stepNumber }) => { + // Switch to a different model after step 0 + if (stepNumber > 0) { + return { + model: 'anthropic/claude-sonnet-4.5', + }; + } + return {}; + }; + + await agent.stream({ + messages: [{ role: 'user', content: 'test' }], + writable: mockWritable, + prepareStep, + }); + + // Verify prepareStep was passed to the iterator + expect(streamTextIterator).toHaveBeenCalledWith( + expect.objectContaining({ + prepareStep: expect.any(Function), + }) + ); + }); + + it('should provide step information to prepareStep callback', async () => { + const mockModel = createMockModel(); + + const agent = new DurableAgent({ + model: async () => mockModel, + tools: {}, + }); + + const mockWritable = new WritableStream({ + write: vi.fn(), + close: vi.fn(), + }); + + const { streamTextIterator } = await import('./stream-text-iterator.js'); + const mockIterator = { + next: vi.fn().mockResolvedValueOnce({ done: true, value: [] }), + }; + vi.mocked(streamTextIterator).mockReturnValue( + mockIterator as unknown as MockIterator + ); + + const prepareStepCalls: Array<{ + model: unknown; + stepNumber: number; + steps: unknown[]; + messages: LanguageModelV2Prompt; + }> = []; + + const prepareStep: PrepareStepCallback = (info) => { + prepareStepCalls.push({ + model: info.model, + stepNumber: info.stepNumber, + steps: info.steps, + messages: info.messages, + }); + return {}; + }; + + await agent.stream({ + messages: [{ role: 'user', content: 'test' }], + writable: mockWritable, + prepareStep, + }); + + // Verify prepareStep was passed and the function captures expected params + expect(streamTextIterator).toHaveBeenCalledWith( + expect.objectContaining({ + prepareStep: expect.any(Function), + }) + ); + }); + }); + + describe('tool execution with messages', () => { + it('should pass conversation messages to tool execute function', async () => { + // Track what messages were passed to the tool + let receivedMessages: unknown; + let receivedToolCallId: string | undefined; + + const tools: ToolSet = { + testTool: { + description: 'A test tool', + inputSchema: z.object({ query: z.string() }), + execute: async (_input, options) => { + receivedMessages = options.messages; + receivedToolCallId = options.toolCallId; + return { result: 'success' }; + }, + }, + }; + + const mockModel = createMockModel(); + + const agent = new DurableAgent({ + model: async () => mockModel, + tools, + }); + + const mockWritable = new WritableStream({ + write: vi.fn(), + close: vi.fn(), + }); + + // Mock conversation messages that would be accumulated by the iterator + const conversationMessages: LanguageModelV2Prompt = [ + { + role: 'user', + content: [{ type: 'text', text: 'What is the weather?' }], + }, + { + role: 'assistant', + content: [ + { + type: 'tool-call', + toolCallId: 'test-call-id', + toolName: 'testTool', + input: { query: 'weather' }, + }, + ], + }, + ]; + + const { streamTextIterator } = await import('./stream-text-iterator.js'); + const mockIterator = { + next: vi + .fn() + .mockResolvedValueOnce({ + done: false, + value: { + toolCalls: [ + { + toolCallId: 'test-call-id', + toolName: 'testTool', + input: '{"query":"weather"}', + } as LanguageModelV2ToolCall, + ], + messages: conversationMessages, + }, + }) + .mockResolvedValueOnce({ done: true, value: [] }), + }; + vi.mocked(streamTextIterator).mockReturnValue( + mockIterator as unknown as MockIterator + ); + + await agent.stream({ + messages: [{ role: 'user', content: 'What is the weather?' }], + writable: mockWritable, + }); + + // Verify that messages were passed to the tool + expect(receivedToolCallId).toBe('test-call-id'); + expect(receivedMessages).toBeDefined(); + expect(Array.isArray(receivedMessages)).toBe(true); + expect(receivedMessages).toEqual(conversationMessages); + }); + + it('should pass messages to multiple tools in parallel execution', async () => { + // Track messages received by each tool + const receivedByTools: Record = {}; + + const tools: ToolSet = { + weatherTool: { + description: 'Get weather', + inputSchema: z.object({ city: z.string() }), + execute: async (_input, options) => { + receivedByTools['weatherTool'] = options.messages; + return { temp: 72 }; + }, + }, + newsTool: { + description: 'Get news', + inputSchema: z.object({ topic: z.string() }), + execute: async (_input, options) => { + receivedByTools['newsTool'] = options.messages; + return { headlines: ['News 1'] }; + }, + }, + }; + + const mockModel = createMockModel(); + + const agent = new DurableAgent({ + model: async () => mockModel, + tools, + }); + + const mockWritable = new WritableStream({ + write: vi.fn(), + close: vi.fn(), + }); + + const conversationMessages: LanguageModelV2Prompt = [ + { + role: 'user', + content: [{ type: 'text', text: 'Weather and news please' }], + }, + { + role: 'assistant', + content: [ + { + type: 'tool-call', + toolCallId: 'weather-call', + toolName: 'weatherTool', + input: { city: 'NYC' }, + }, + { + type: 'tool-call', + toolCallId: 'news-call', + toolName: 'newsTool', + input: { topic: 'tech' }, + }, + ], + }, + ]; + + const { streamTextIterator } = await import('./stream-text-iterator.js'); + const mockIterator = { + next: vi + .fn() + .mockResolvedValueOnce({ + done: false, + value: { + toolCalls: [ + { + toolCallId: 'weather-call', + toolName: 'weatherTool', + input: '{"city":"NYC"}', + } as LanguageModelV2ToolCall, + { + toolCallId: 'news-call', + toolName: 'newsTool', + input: '{"topic":"tech"}', + } as LanguageModelV2ToolCall, + ], + messages: conversationMessages, + }, + }) + .mockResolvedValueOnce({ done: true, value: [] }), + }; + vi.mocked(streamTextIterator).mockReturnValue( + mockIterator as unknown as MockIterator + ); + + await agent.stream({ + messages: [{ role: 'user', content: 'Weather and news please' }], + writable: mockWritable, + }); + + // Both tools should have received the same conversation messages + expect(receivedByTools['weatherTool']).toEqual(conversationMessages); + expect(receivedByTools['newsTool']).toEqual(conversationMessages); + }); + + it('should pass updated messages on subsequent tool call rounds', async () => { + // Track messages received in each round + const messagesPerRound: unknown[] = []; + + const tools: ToolSet = { + searchTool: { + description: 'Search for info', + inputSchema: z.object({ query: z.string() }), + execute: async (_input, options) => { + messagesPerRound.push(options.messages); + return { found: true }; + }, + }, + }; + + const mockModel = createMockModel(); + + const agent = new DurableAgent({ + model: async () => mockModel, + tools, + }); + + const mockWritable = new WritableStream({ + write: vi.fn(), + close: vi.fn(), + }); + + // First round messages + const firstRoundMessages: LanguageModelV2Prompt = [ + { role: 'user', content: [{ type: 'text', text: 'Search for cats' }] }, + { + role: 'assistant', + content: [ + { + type: 'tool-call', + toolCallId: 'search-1', + toolName: 'searchTool', + input: { query: 'cats' }, + }, + ], + }, + ]; + + // Second round messages (includes first tool result) + const secondRoundMessages: LanguageModelV2Prompt = [ + ...firstRoundMessages, + { + role: 'tool', + content: [ + { + type: 'tool-result', + toolCallId: 'search-1', + toolName: 'searchTool', + output: { type: 'text', value: '{"found":true}' }, + }, + ], + }, + { + role: 'assistant', + content: [ + { + type: 'tool-call', + toolCallId: 'search-2', + toolName: 'searchTool', + input: { query: 'dogs' }, + }, + ], + }, + ]; + + const { streamTextIterator } = await import('./stream-text-iterator.js'); + const mockIterator = { + next: vi + .fn() + // First tool call round + .mockResolvedValueOnce({ + done: false, + value: { + toolCalls: [ + { + toolCallId: 'search-1', + toolName: 'searchTool', + input: '{"query":"cats"}', + } as LanguageModelV2ToolCall, + ], + messages: firstRoundMessages, + }, + }) + // Second tool call round + .mockResolvedValueOnce({ + done: false, + value: { + toolCalls: [ + { + toolCallId: 'search-2', + toolName: 'searchTool', + input: '{"query":"dogs"}', + } as LanguageModelV2ToolCall, + ], + messages: secondRoundMessages, + }, + }) + .mockResolvedValueOnce({ done: true, value: [] }), + }; + vi.mocked(streamTextIterator).mockReturnValue( + mockIterator as unknown as MockIterator + ); + + await agent.stream({ + messages: [{ role: 'user', content: 'Search for cats' }], + writable: mockWritable, + }); + + // Verify messages grow with each round + expect(messagesPerRound).toHaveLength(2); + expect(messagesPerRound[0]).toEqual(firstRoundMessages); + expect(messagesPerRound[1]).toEqual(secondRoundMessages); + // Second round should have more messages than first + expect((messagesPerRound[1] as unknown[]).length).toBeGreaterThan( + (messagesPerRound[0] as unknown[]).length + ); + }); + }); }); diff --git a/packages/ai/src/agent/durable-agent.ts b/packages/ai/src/agent/durable-agent.ts index cfe3c50e2..2f05f363a 100644 --- a/packages/ai/src/agent/durable-agent.ts +++ b/packages/ai/src/agent/durable-agent.ts @@ -1,11 +1,13 @@ import type { LanguageModelV2, + LanguageModelV2Prompt, LanguageModelV2ToolCall, LanguageModelV2ToolResultPart, } from '@ai-sdk/provider'; import { asSchema, type ModelMessage, + type StepResult, type StopCondition, type StreamTextOnStepFinishCallback, type ToolSet, @@ -15,6 +17,57 @@ import { convertToLanguageModelPrompt, standardizePrompt } from 'ai/internal'; import { FatalError } from 'workflow'; import { streamTextIterator } from './stream-text-iterator.js'; +/** + * Information passed to the prepareStep callback. + */ +export interface PrepareStepInfo { + /** + * The current model configuration (string or function). + */ + model: string | (() => Promise); + + /** + * The current step number (0-indexed). + */ + stepNumber: number; + + /** + * All previous steps with their results. + */ + steps: StepResult[]; + + /** + * The messages that will be sent to the model. + * This is the LanguageModelV2Prompt format used internally. + */ + messages: LanguageModelV2Prompt; +} + +/** + * Return type from the prepareStep callback. + * All properties are optional - only return the ones you want to override. + */ +export interface PrepareStepResult { + /** + * Override the model for this step. + */ + model?: string | (() => Promise); + + /** + * Override the messages for this step. + * Use this for context management or message injection. + */ + messages?: LanguageModelV2Prompt; +} + +/** + * Callback function called before each step in the agent loop. + * Use this to modify settings, manage context, or implement dynamic behavior. + */ +export type PrepareStepCallback = ( + info: PrepareStepInfo +) => PrepareStepResult | Promise; + /** * Configuration options for creating a {@link DurableAgent} instance. */ @@ -43,7 +96,7 @@ export interface DurableAgentOptions { /** * Options for the {@link DurableAgent.stream} method. */ -export interface DurableAgentStreamOptions { +export interface DurableAgentStreamOptions { /** * The conversation messages to process. Should follow the AI SDK's ModelMessage format. */ @@ -89,6 +142,26 @@ export interface DurableAgentStreamOptions { * Callback function to be called after each step completes. */ onStepFinish?: StreamTextOnStepFinishCallback; + + /** + * Callback function called before each step in the agent loop. + * Use this to modify settings, manage context, or inject messages dynamically. + * + * @example + * ```typescript + * prepareStep: async ({ messages, stepNumber }) => { + * // Inject messages from a queue + * const queuedMessages = await getQueuedMessages(); + * if (queuedMessages.length > 0) { + * return { + * messages: [...messages, ...queuedMessages], + * }; + * } + * return {}; + * } + * ``` + */ + prepareStep?: PrepareStepCallback; } /** @@ -134,7 +207,9 @@ export class DurableAgent { throw new Error('Not implemented'); } - async stream(options: DurableAgentStreamOptions) { + async stream( + options: DurableAgentStreamOptions + ) { const prompt = await standardizePrompt({ system: options.system || this.system, messages: options.messages, @@ -154,15 +229,16 @@ export class DurableAgent { stopConditions: options.stopWhen, sendStart: options.sendStart ?? true, onStepFinish: options.onStepFinish, + prepareStep: options.prepareStep, }); let result = await iterator.next(); while (!result.done) { - const toolCalls = result.value; + const { toolCalls, messages } = result.value; const toolResults = await Promise.all( toolCalls.map( (toolCall): Promise => - executeTool(toolCall, this.tools) + executeTool(toolCall, this.tools, messages) ) ); result = await iterator.next(toolResults); @@ -209,7 +285,8 @@ async function closeStream( async function executeTool( toolCall: LanguageModelV2ToolCall, - tools: ToolSet + tools: ToolSet, + messages: LanguageModelV2Prompt ): Promise { const tool = tools[toolCall.toolName]; if (!tool) throw new Error(`Tool "${toolCall.toolName}" not found`); @@ -228,8 +305,8 @@ async function executeTool( try { const toolResult = await tool.execute(input.value, { toolCallId: toolCall.toolCallId, - // TODO: pass the proper messages to the tool (we'd need to pass them through the iterator) - messages: [], + // Pass the conversation messages to the tool so it has context about the conversation + messages, }); return { diff --git a/packages/ai/src/agent/stream-text-iterator.ts b/packages/ai/src/agent/stream-text-iterator.ts index db137ec88..cbe808760 100644 --- a/packages/ai/src/agent/stream-text-iterator.ts +++ b/packages/ai/src/agent/stream-text-iterator.ts @@ -11,8 +11,20 @@ import type { UIMessageChunk, } from 'ai'; import { doStreamStep, type ModelStopCondition } from './do-stream-step.js'; +import type { PrepareStepCallback } from './durable-agent.js'; import { toolsToModelTools } from './tools-to-model-tools.js'; +/** + * The value yielded by the stream text iterator when tool calls are requested. + * Contains both the tool calls and the current conversation messages. + */ +export interface StreamTextIteratorYieldValue { + /** The tool calls requested by the model */ + toolCalls: LanguageModelV2ToolCall[]; + /** The conversation messages up to (and including) the tool call request */ + messages: LanguageModelV2Prompt; +} + // This runs in the workflow context export async function* streamTextIterator({ prompt, @@ -22,6 +34,7 @@ export async function* streamTextIterator({ stopConditions, sendStart = true, onStepFinish, + prepareStep, }: { prompt: LanguageModelV2Prompt; tools: ToolSet; @@ -30,21 +43,42 @@ export async function* streamTextIterator({ stopConditions?: ModelStopCondition[] | ModelStopCondition; sendStart?: boolean; onStepFinish?: StreamTextOnStepFinishCallback; + prepareStep?: PrepareStepCallback; }): AsyncGenerator< - LanguageModelV2ToolCall[], + StreamTextIteratorYieldValue, LanguageModelV2Prompt, LanguageModelV2ToolResultPart[] > { - const conversationPrompt = [...prompt]; // Create a mutable copy + let conversationPrompt = [...prompt]; // Create a mutable copy + let currentModel = model; const steps: StepResult[] = []; let done = false; let isFirstIteration = true; + let stepNumber = 0; while (!done) { + // Call prepareStep callback before each step if provided + if (prepareStep) { + const prepareResult = await prepareStep({ + model: currentModel, + stepNumber, + steps, + messages: conversationPrompt, + }); + + // Apply any overrides from prepareStep + if (prepareResult.model !== undefined) { + currentModel = prepareResult.model; + } + if (prepareResult.messages !== undefined) { + conversationPrompt = [...prepareResult.messages]; + } + } + const { toolCalls, finish, step } = await doStreamStep( conversationPrompt, - model, + currentModel, writable, toolsToModelTools(tools), { @@ -52,6 +86,7 @@ export async function* streamTextIterator({ } ); isFirstIteration = false; + stepNumber++; steps.push(step); if (finish?.finishReason === 'tool-calls') { @@ -66,8 +101,9 @@ export async function* streamTextIterator({ })), }); - // Yield the tool calls and wait for results - const toolResults = yield toolCalls; + // Yield the tool calls along with the current conversation messages + // This allows executeTool to pass the conversation context to tool execute functions + const toolResults = yield { toolCalls, messages: conversationPrompt }; await writeToolOutputToUI(writable, toolResults);