diff --git a/.changeset/short-seas-flash.md b/.changeset/short-seas-flash.md
new file mode 100644
index 00000000000..4186e432753
--- /dev/null
+++ b/.changeset/short-seas-flash.md
@@ -0,0 +1,11 @@
+---
+'@ai-sdk/provider-utils': patch
+'@ai-sdk/anthropic': patch
+'@ai-sdk/provider': patch
+'@ai-sdk/mistral': patch
+'@ai-sdk/google': patch
+'@ai-sdk/openai': patch
+'ai': patch
+---
+
+ai/core: add support for getting raw response headers.
diff --git a/examples/ai-core/src/stream-text/openai-response-headers.ts b/examples/ai-core/src/stream-text/openai-response-headers.ts
new file mode 100644
index 00000000000..ab4cb1b71db
--- /dev/null
+++ b/examples/ai-core/src/stream-text/openai-response-headers.ts
@@ -0,0 +1,24 @@
+import { openai } from '@ai-sdk/openai';
+import { experimental_streamText } from 'ai';
+import dotenv from 'dotenv';
+
+dotenv.config();
+
+async function main() {
+  const result = await experimental_streamText({
+    model: openai('gpt-3.5-turbo'),
+    maxTokens: 512,
+    temperature: 0.3,
+    maxRetries: 5,
+    prompt: 'Invent a new holiday and describe its traditions.',
+  });
+
+  console.log(`Request ID: ${result.rawResponse?.headers?.['x-request-id']}`);
+  console.log();
+
+  for await (const textPart of result.textStream) {
+    process.stdout.write(textPart);
+  }
+}
+
+main().catch(console.error);
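Note: the non-streaming path exposes the same optional `rawResponse` field on its result. A minimal sketch of the `experimental_generateText` counterpart (the file name and the `x-request-id` header are illustrative; header names vary by provider):

```ts
// hypothetical example: examples/ai-core/src/generate-text/openai-response-headers.ts
import { openai } from '@ai-sdk/openai';
import { experimental_generateText } from 'ai';
import dotenv from 'dotenv';

dotenv.config();

async function main() {
  const result = await experimental_generateText({
    model: openai('gpt-3.5-turbo'),
    prompt: 'Invent a new holiday and describe its traditions.',
  });

  // rawResponse and headers are optional, so guard every access:
  console.log(`Request ID: ${result.rawResponse?.headers?.['x-request-id']}`);
  console.log(result.text);
}

main().catch(console.error);
```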
diff --git a/packages/anthropic/src/anthropic-messages-language-model.test.ts b/packages/anthropic/src/anthropic-messages-language-model.test.ts
index 6b7aaef1e6c..3beb86411c2 100644
--- a/packages/anthropic/src/anthropic-messages-language-model.test.ts
+++ b/packages/anthropic/src/anthropic-messages-language-model.test.ts
@@ -11,10 +11,7 @@ const TEST_PROMPT: LanguageModelV1Prompt = [
   { role: 'user', content: [{ type: 'text', text: 'Hello' }] },
 ];
 
-const provider = createAnthropic({
-  apiKey: 'test-api-key',
-});
-
+const provider = createAnthropic({ apiKey: 'test-api-key' });
 const model = provider.chat('claude-3-haiku-20240307');
 
 describe('doGenerate', () => {
@@ -181,6 +178,28 @@ describe('doGenerate', () => {
     });
   });
 
+  it('should expose the raw response headers', async () => {
+    prepareJsonResponse({});
+
+    server.responseHeaders = {
+      'test-header': 'test-value',
+    };
+
+    const { rawResponse } = await model.doGenerate({
+      inputFormat: 'prompt',
+      mode: { type: 'regular' },
+      prompt: TEST_PROMPT,
+    });
+
+    expect(rawResponse?.headers).toStrictEqual({
+      // default headers:
+      'content-type': 'application/json',
+
+      // custom header
+      'test-header': 'test-value',
+    });
+  });
+
   it('should pass the model and the messages', async () => {
     prepareJsonResponse({});
 
@@ -279,6 +298,30 @@ describe('doStream', () => {
     ]);
   });
 
+  it('should expose the raw response headers', async () => {
+    prepareStreamResponse({ content: [] });
+
+    server.responseHeaders = {
+      'test-header': 'test-value',
+    };
+
+    const { rawResponse } = await model.doStream({
+      inputFormat: 'prompt',
+      mode: { type: 'regular' },
+      prompt: TEST_PROMPT,
+    });
+
+    expect(rawResponse?.headers).toStrictEqual({
+      // default headers:
+      'content-type': 'text/event-stream',
+      'cache-control': 'no-cache',
+      connection: 'keep-alive',
+
+      // custom header
+      'test-header': 'test-value',
+    });
+  });
+
   it('should pass the messages and the model', async () => {
     prepareStreamResponse({ content: [] });
 
diff --git a/packages/anthropic/src/anthropic-messages-language-model.ts b/packages/anthropic/src/anthropic-messages-language-model.ts
index 5e0bb4082ad..423c670cb43 100644
--- a/packages/anthropic/src/anthropic-messages-language-model.ts
+++ b/packages/anthropic/src/anthropic-messages-language-model.ts
@@ -164,7 +164,7 @@ export class AnthropicMessagesLanguageModel implements LanguageModelV1 {
   ): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> {
     const { args, warnings } = this.getArgs(options);
 
-    const response = await postJsonToApi({
+    const { responseHeaders, value: response } = await postJsonToApi({
       url: `${this.config.baseURL}/messages`,
       headers: this.config.headers(),
       body: args,
@@ -210,6 +210,7 @@ export class AnthropicMessagesLanguageModel implements LanguageModelV1 {
         completionTokens: response.usage.output_tokens,
       },
       rawCall: { rawPrompt, rawSettings },
+      rawResponse: { headers: responseHeaders },
       warnings,
     };
   }
@@ -219,7 +220,7 @@ export class AnthropicMessagesLanguageModel implements LanguageModelV1 {
   ): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
     const { args, warnings } = this.getArgs(options);
 
-    const response = await postJsonToApi({
+    const { responseHeaders, value: response } = await postJsonToApi({
       url: `${this.config.baseURL}/messages`,
       headers: this.config.headers(),
       body: {
@@ -296,6 +297,7 @@ export class AnthropicMessagesLanguageModel implements LanguageModelV1 {
         }),
       ),
       rawCall: { rawPrompt, rawSettings },
+      rawResponse: { headers: responseHeaders },
       warnings,
     };
   }
diff --git a/packages/core/core/generate-object/generate-object.ts b/packages/core/core/generate-object/generate-object.ts
index bd50eb0b22e..37dec254d21 100644
--- a/packages/core/core/generate-object/generate-object.ts
+++ b/packages/core/core/generate-object/generate-object.ts
@@ -94,6 +94,7 @@ Default and recommended: 'auto' (best mode for the model).
   let finishReason: LanguageModelV1FinishReason;
   let usage: Parameters<typeof calculateTokenUsage>[0];
   let warnings: LanguageModelV1CallWarning[] | undefined;
+  let rawResponse: { headers?: Record<string, string> } | undefined;
   let logprobs: LanguageModelV1LogProbs | undefined;
 
   switch (mode) {
@@ -122,6 +123,7 @@ Default and recommended: 'auto' (best mode for the model).
       finishReason = generateResult.finishReason;
       usage = generateResult.usage;
       warnings = generateResult.warnings;
+      rawResponse = generateResult.rawResponse;
       logprobs = generateResult.logprobs;
 
       break;
@@ -152,6 +154,7 @@ Default and recommended: 'auto' (best mode for the model).
       finishReason = generateResult.finishReason;
       usage = generateResult.usage;
       warnings = generateResult.warnings;
+      rawResponse = generateResult.rawResponse;
       logprobs = generateResult.logprobs;
 
       break;
@@ -192,6 +195,7 @@ Default and recommended: 'auto' (best mode for the model).
       finishReason = generateResult.finishReason;
       usage = generateResult.usage;
       warnings = generateResult.warnings;
+      rawResponse = generateResult.rawResponse;
       logprobs = generateResult.logprobs;
 
       break;
@@ -218,6 +222,7 @@ Default and recommended: 'auto' (best mode for the model).
     finishReason,
     usage: calculateTokenUsage(usage),
     warnings,
+    rawResponse,
    logprobs,
  });
 }
@@ -246,6 +251,16 @@ Warnings from the model provider (e.g. unsupported settings)
 */
   readonly warnings: LanguageModelV1CallWarning[] | undefined;
 
+  /**
+Optional raw response data.
+   */
+  rawResponse?: {
+    /**
+Response headers.
+     */
+    headers?: Record<string, string>;
+  };
+
   /**
 Logprobs for the completion.
 `undefined` if the mode does not support logprobs or if was not enabled
@@ -257,12 +272,16 @@ Logprobs for the completion.
     finishReason: LanguageModelV1FinishReason;
     usage: TokenUsage;
     warnings: LanguageModelV1CallWarning[] | undefined;
+    rawResponse?: {
+      headers?: Record<string, string>;
+    };
     logprobs: LanguageModelV1LogProbs | undefined;
   }) {
     this.object = options.object;
     this.finishReason = options.finishReason;
     this.usage = options.usage;
     this.warnings = options.warnings;
+    this.rawResponse = options.rawResponse;
     this.logprobs = options.logprobs;
   }
 }
diff --git a/packages/core/core/generate-object/stream-object.ts b/packages/core/core/generate-object/stream-object.ts
index 0a8037f29a3..f08f9366919 100644
--- a/packages/core/core/generate-object/stream-object.ts
+++ b/packages/core/core/generate-object/stream-object.ts
@@ -220,6 +220,7 @@ Default and recommended: 'auto' (best mode for the model).
   return new StreamObjectResult({
     stream: result.stream.pipeThrough(new TransformStream(transformer)),
     warnings: result.warnings,
+    rawResponse: result.rawResponse,
   });
 }
 
@@ -259,15 +260,30 @@ Warnings from the model provider (e.g. unsupported settings)
 */
   readonly warnings: LanguageModelV1CallWarning[] | undefined;
 
+  /**
+Optional raw response data.
+   */
+  rawResponse?: {
+    /**
+Response headers.
+     */
+    headers?: Record<string, string>;
+  };
+
   constructor({
     stream,
     warnings,
+    rawResponse,
   }: {
     stream: ReadableStream<string | ErrorStreamPart>;
     warnings: LanguageModelV1CallWarning[] | undefined;
+    rawResponse?: {
+      headers?: Record<string, string>;
+    };
   }) {
     this.originalStream = stream;
     this.warnings = warnings;
+    this.rawResponse = rawResponse;
   }
 
   get partialObjectStream(): AsyncIterableStream<DeepPartial<T>> {
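Both object-generation results now carry the same optional `rawResponse`. A minimal consumer sketch, assuming the experimental object API of this release (the zod schema and the `x-request-id` header are illustrative):

```ts
import { openai } from '@ai-sdk/openai';
import { experimental_generateObject } from 'ai';
import { z } from 'zod';

async function main() {
  const result = await experimental_generateObject({
    model: openai('gpt-3.5-turbo'),
    schema: z.object({ name: z.string(), traditions: z.array(z.string()) }),
    prompt: 'Invent a new holiday.',
  });

  // rawResponse is optional and may be undefined for some modes/providers:
  console.log(result.rawResponse?.headers?.['x-request-id']);
  console.log(result.object);
}

main().catch(console.error);
```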
diff --git a/packages/core/core/generate-text/generate-text.ts b/packages/core/core/generate-text/generate-text.ts
index 345b5944ecf..015a2200d8f 100644
--- a/packages/core/core/generate-text/generate-text.ts
+++ b/packages/core/core/generate-text/generate-text.ts
@@ -116,6 +116,7 @@ The tools that the model can call. The model needs to support calling tools.
     finishReason: modelResponse.finishReason,
     usage: calculateTokenUsage(modelResponse.usage),
     warnings: modelResponse.warnings,
+    rawResponse: modelResponse.rawResponse,
     logprobs: modelResponse.logprobs,
   });
 }
@@ -188,6 +189,16 @@ Warnings from the model provider (e.g. unsupported settings)
 */
   readonly warnings: LanguageModelV1CallWarning[] | undefined;
 
+  /**
+Optional raw response data.
+   */
+  rawResponse?: {
+    /**
+Response headers.
+     */
+    headers?: Record<string, string>;
+  };
+
   /**
 Logprobs for the completion.
 `undefined` if the mode does not support logprobs or if was not enabled
@@ -201,6 +212,9 @@ Logprobs for the completion.
     finishReason: LanguageModelV1FinishReason;
     usage: TokenUsage;
     warnings: LanguageModelV1CallWarning[] | undefined;
+    rawResponse?: {
+      headers?: Record<string, string>;
+    };
     logprobs: LanguageModelV1LogProbs | undefined;
   }) {
     this.text = options.text;
@@ -209,6 +223,7 @@ Logprobs for the completion.
     this.finishReason = options.finishReason;
     this.usage = options.usage;
     this.warnings = options.warnings;
+    this.rawResponse = options.rawResponse;
     this.logprobs = options.logprobs;
   }
 }
diff --git a/packages/core/core/generate-text/stream-text.test.ts b/packages/core/core/generate-text/stream-text.test.ts
index a9628def5a1..0ef3e5a6b78 100644
--- a/packages/core/core/generate-text/stream-text.test.ts
+++ b/packages/core/core/generate-text/stream-text.test.ts
@@ -4,9 +4,8 @@ import { convertArrayToReadableStream } from '../test/convert-array-to-readable-
 import { convertAsyncIterableToArray } from '../test/convert-async-iterable-to-array';
 import { convertReadableStreamToArray } from '../test/convert-readable-stream-to-array';
 import { MockLanguageModelV1 } from '../test/mock-language-model-v1';
-import { experimental_streamText } from './stream-text';
-import { ServerResponse } from 'node:http';
 import { createMockServerResponse } from '../test/mock-server-response';
+import { experimental_streamText } from './stream-text';
 
 describe('result.textStream', () => {
   it('should send text deltas', async () => {
diff --git a/packages/core/core/generate-text/stream-text.ts b/packages/core/core/generate-text/stream-text.ts
index 68497fa8383..7275c9cafe8 100644
--- a/packages/core/core/generate-text/stream-text.ts
+++ b/packages/core/core/generate-text/stream-text.ts
@@ -85,7 +85,7 @@ The tools that the model can call. The model needs to support calling tools.
 }): Promise<StreamTextResult<TOOLS>> {
   const retry = retryWithExponentialBackoff({ maxRetries });
   const validatedPrompt = getValidatedPrompt({ system, prompt, messages });
-  const { stream, warnings } = await retry(() =>
+  const { stream, warnings, rawResponse } = await retry(() =>
     model.doStream({
       mode: {
         type: 'regular',
@@ -112,6 +112,7 @@ The tools that the model can call. The model needs to support calling tools.
       generatorStream: stream,
     }),
     warnings,
+    rawResponse,
   });
 }
 
@@ -152,15 +153,30 @@ Warnings from the model provider (e.g. unsupported settings)
 */
   readonly warnings: LanguageModelV1CallWarning[] | undefined;
 
+  /**
+Optional raw response data.
+   */
+  rawResponse?: {
+    /**
+Response headers.
+     */
+    headers?: Record<string, string>;
+  };
+
   constructor({
     stream,
     warnings,
+    rawResponse,
   }: {
     stream: ReadableStream<TextStreamPart<TOOLS>>;
     warnings: LanguageModelV1CallWarning[] | undefined;
+    rawResponse?: {
+      headers?: Record<string, string>;
+    };
   }) {
     this.originalStream = stream;
     this.warnings = warnings;
+    this.rawResponse = rawResponse;
   }
 
   /**
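Because `doStream` resolves before the stream is consumed, the headers are available on `StreamTextResult` up front. A small sketch of using that for rate-limit visibility (the `x-ratelimit-remaining-tokens` header is illustrative and provider-specific):

```ts
import { openai } from '@ai-sdk/openai';
import { experimental_streamText } from 'ai';

async function main() {
  const result = await experimental_streamText({
    model: openai('gpt-3.5-turbo'),
    prompt: 'Write a haiku about response headers.',
  });

  // headers are known before any text delta arrives:
  const remaining =
    result.rawResponse?.headers?.['x-ratelimit-remaining-tokens'];
  if (remaining != null && Number(remaining) < 1000) {
    console.warn(`close to the token rate limit: ${remaining} remaining`);
  }

  for await (const textPart of result.textStream) {
    process.stdout.write(textPart);
  }
}

main().catch(console.error);
```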
diff --git a/packages/google/src/google-generative-ai-language-model.test.ts b/packages/google/src/google-generative-ai-language-model.test.ts
index 222093b271a..a6d90c3e097 100644
--- a/packages/google/src/google-generative-ai-language-model.test.ts
+++ b/packages/google/src/google-generative-ai-language-model.test.ts
@@ -127,6 +127,28 @@ describe('doGenerate', () => {
     expect(finishReason).toStrictEqual('tool-calls');
   });
 
+  it('should expose the raw response headers', async () => {
+    prepareJsonResponse({ content: '' });
+
+    server.responseHeaders = {
+      'test-header': 'test-value',
+    };
+
+    const { rawResponse } = await model.doGenerate({
+      inputFormat: 'prompt',
+      mode: { type: 'regular' },
+      prompt: TEST_PROMPT,
+    });
+
+    expect(rawResponse?.headers).toStrictEqual({
+      // default headers:
+      'content-type': 'application/json',
+
+      // custom header
+      'test-header': 'test-value',
+    });
+  });
+
   it('should pass the model and the messages', async () => {
     prepareJsonResponse({ content: '' });
 
@@ -225,6 +247,30 @@ describe('doStream', () => {
     ]);
   });
 
+  it('should expose the raw response headers', async () => {
+    prepareStreamResponse({ content: [] });
+
+    server.responseHeaders = {
+      'test-header': 'test-value',
+    };
+
+    const { rawResponse } = await model.doStream({
+      inputFormat: 'prompt',
+      mode: { type: 'regular' },
+      prompt: TEST_PROMPT,
+    });
+
+    expect(rawResponse?.headers).toStrictEqual({
+      // default headers:
+      'content-type': 'text/event-stream',
+      'cache-control': 'no-cache',
+      connection: 'keep-alive',
+
+      // custom header
+      'test-header': 'test-value',
+    });
+  });
+
   it('should pass the messages', async () => {
     prepareStreamResponse({ content: [''] });
 
diff --git a/packages/google/src/google-generative-ai-language-model.ts b/packages/google/src/google-generative-ai-language-model.ts
index 3f13ce5457d..faedcf3d415 100644
--- a/packages/google/src/google-generative-ai-language-model.ts
+++ b/packages/google/src/google-generative-ai-language-model.ts
@@ -151,7 +151,7 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV1 {
   ): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> {
     const { args, warnings } = this.getArgs(options);
 
-    const response = await postJsonToApi({
+    const { responseHeaders, value: response } = await postJsonToApi({
       url: `${this.config.baseURL}/${this.modelId}:generateContent`,
       headers: this.config.headers(),
       body: args,
@@ -180,6 +180,7 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV1 {
         completionTokens: candidate.tokenCount ?? NaN,
       },
       rawCall: { rawPrompt, rawSettings },
+      rawResponse: { headers: responseHeaders },
       warnings,
     };
   }
@@ -189,7 +190,7 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV1 {
   ): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
     const { args, warnings } = this.getArgs(options);
 
-    const response = await postJsonToApi({
+    const { responseHeaders, value: response } = await postJsonToApi({
       url: `${this.config.baseURL}/${this.modelId}:streamGenerateContent?alt=sse`,
       headers: this.config.headers(),
       body: args,
@@ -287,6 +288,7 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV1 {
         }),
       ),
       rawCall: { rawPrompt, rawSettings },
+      rawResponse: { headers: responseHeaders },
       warnings,
     };
   }
diff --git a/packages/mistral/src/mistral-chat-language-model.test.ts b/packages/mistral/src/mistral-chat-language-model.test.ts
index 3fccb244291..f3dc23ddee4 100644
--- a/packages/mistral/src/mistral-chat-language-model.test.ts
+++ b/packages/mistral/src/mistral-chat-language-model.test.ts
@@ -11,6 +11,7 @@ const TEST_PROMPT: LanguageModelV1Prompt = [
 ];
 
 const provider = createMistral({ apiKey: 'test-api-key' });
+const model = provider.chat('mistral-small-latest');
 
 describe('doGenerate', () => {
   const server = new JsonTestServer(
@@ -58,7 +59,7 @@ describe('doGenerate', () => {
   it('should extract text response', async () => {
     prepareJsonResponse({ content: 'Hello, World!' });
 
-    const { text } = await provider.chat('mistral-small-latest').doGenerate({
+    const { text } = await model.doGenerate({
       inputFormat: 'prompt',
       mode: { type: 'regular' },
       prompt: TEST_PROMPT,
@@ -73,7 +74,7 @@ describe('doGenerate', () => {
       usage: { prompt_tokens: 20, total_tokens: 25, completion_tokens: 5 },
     });
 
-    const { usage } = await provider.chat('mistral-small-latest').doGenerate({
+    const { usage } = await model.doGenerate({
       inputFormat: 'prompt',
       mode: { type: 'regular' },
       prompt: TEST_PROMPT,
@@ -85,10 +86,32 @@ describe('doGenerate', () => {
     });
   });
 
+  it('should expose the raw response headers', async () => {
+    prepareJsonResponse({ content: '' });
+
+    server.responseHeaders = {
+      'test-header': 'test-value',
+    };
+
+    const { rawResponse } = await model.doGenerate({
+      inputFormat: 'prompt',
+      mode: { type: 'regular' },
+      prompt: TEST_PROMPT,
+    });
+
+    expect(rawResponse?.headers).toStrictEqual({
+      // default headers:
+      'content-type': 'application/json',
+
+      // custom header
+      'test-header': 'test-value',
+    });
+  });
+
   it('should pass the model and the messages', async () => {
     prepareJsonResponse({ content: '' });
 
-    await provider.chat('mistral-small-latest').doGenerate({
+    await model.doGenerate({
       inputFormat: 'prompt',
       mode: { type: 'regular' },
       prompt: TEST_PROMPT,
@@ -167,7 +190,7 @@ describe('doStream', () => {
   it('should stream text deltas', async () => {
     prepareStreamResponse({ content: ['Hello', ', ', 'world!'] });
 
-    const { stream } = await provider.chat('mistral-small-latest').doStream({
+    const { stream } = await model.doStream({
       inputFormat: 'prompt',
       mode: { type: 'regular' },
       prompt: TEST_PROMPT,
@@ -251,10 +274,34 @@ describe('doStream', () => {
     ]);
   });
 
+  it('should expose the raw response headers', async () => {
+    prepareStreamResponse({ content: [] });
+
+    server.responseHeaders = {
+      'test-header': 'test-value',
+    };
+
+    const { rawResponse } = await model.doStream({
+      inputFormat: 'prompt',
+      mode: { type: 'regular' },
+      prompt: TEST_PROMPT,
+    });
+
+    expect(rawResponse?.headers).toStrictEqual({
+      // default headers:
+      'content-type': 'text/event-stream',
+      'cache-control': 'no-cache',
+      connection: 'keep-alive',
+
+      // custom header
+      'test-header': 'test-value',
+    });
+  });
+
   it('should pass the messages', async () => {
     prepareStreamResponse({ content: [''] });
 
-    await provider.chat('mistral-small-latest').doStream({
+    await model.doStream({
       inputFormat: 'prompt',
       mode: { type: 'regular' },
       prompt: TEST_PROMPT,
diff --git a/packages/mistral/src/mistral-chat-language-model.ts b/packages/mistral/src/mistral-chat-language-model.ts
index 84cae9eb5d6..a553a3f3e93 100644
--- a/packages/mistral/src/mistral-chat-language-model.ts
+++ b/packages/mistral/src/mistral-chat-language-model.ts
@@ -155,7 +155,7 @@ export class MistralChatLanguageModel implements LanguageModelV1 {
   ): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> {
     const { args, warnings } = this.getArgs(options);
 
-    const response = await postJsonToApi({
+    const { responseHeaders, value: response } = await postJsonToApi({
       url: `${this.config.baseURL}/chat/completions`,
       headers: this.config.headers(),
       body: args,
@@ -183,6 +183,7 @@ export class MistralChatLanguageModel implements LanguageModelV1 {
         completionTokens: response.usage.completion_tokens,
       },
       rawCall: { rawPrompt, rawSettings },
+      rawResponse: { headers: responseHeaders },
       warnings,
     };
   }
@@ -192,7 +193,7 @@ export class MistralChatLanguageModel implements LanguageModelV1 {
   ): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
     const { args, warnings } = this.getArgs(options);
 
-    const response = await postJsonToApi({
+    const { responseHeaders, value: response } = await postJsonToApi({
       url: `${this.config.baseURL}/chat/completions`,
       headers: this.config.headers(),
       body: {
@@ -287,6 +288,7 @@ export class MistralChatLanguageModel implements LanguageModelV1 {
         }),
       ),
       rawCall: { rawPrompt, rawSettings },
+      rawResponse: { headers: responseHeaders },
       warnings,
     };
   }
diff --git a/packages/openai/src/openai-chat-language-model.test.ts b/packages/openai/src/openai-chat-language-model.test.ts
index 41f61d81e3f..8c1825168ed 100644
--- a/packages/openai/src/openai-chat-language-model.test.ts
+++ b/packages/openai/src/openai-chat-language-model.test.ts
@@ -106,9 +106,8 @@ const TEST_LOGPROBS = {
   ],
 };
 
-const provider = createOpenAI({
-  apiKey: 'test-api-key',
-});
+const provider = createOpenAI({ apiKey: 'test-api-key' });
+const model = provider.chat('gpt-3.5-turbo');
 
 describe('doGenerate', () => {
   const server = new JsonTestServer(
@@ -168,7 +167,7 @@ describe('doGenerate', () => {
   it('should extract text response', async () => {
     prepareJsonResponse({ content: 'Hello, World!' });
 
-    const { text } = await provider.chat('gpt-3.5-turbo').doGenerate({
+    const { text } = await model.doGenerate({
       inputFormat: 'prompt',
       mode: { type: 'regular' },
       prompt: TEST_PROMPT,
@@ -183,7 +182,7 @@ describe('doGenerate', () => {
       usage: { prompt_tokens: 20, total_tokens: 25, completion_tokens: 5 },
     });
 
-    const { usage } = await provider.chat('gpt-3.5-turbo').doGenerate({
+    const { usage } = await model.doGenerate({
       inputFormat: 'prompt',
       mode: { type: 'regular' },
       prompt: TEST_PROMPT,
@@ -200,8 +199,6 @@ describe('doGenerate', () => {
       logprobs: TEST_LOGPROBS,
     });
 
-    const provider = createOpenAI({ apiKey: 'test-api-key' });
-
     const response = await provider
       .chat('gpt-3.5-turbo', { logprobs: 1 })
       .doGenerate({
@@ -220,20 +217,41 @@ describe('doGenerate', () => {
       finish_reason: 'stop',
     });
 
-    const provider = createOpenAI({ apiKey: 'test-api-key' });
-
-    const response = await provider.chat('gpt-3.5-turbo').doGenerate({
+    const response = await model.doGenerate({
       inputFormat: 'prompt',
       mode: { type: 'regular' },
       prompt: TEST_PROMPT,
     });
+
     expect(response.finishReason).toStrictEqual('stop');
   });
 
+  it('should expose the raw response headers', async () => {
+    prepareJsonResponse({ content: '' });
+
+    server.responseHeaders = {
+      'test-header': 'test-value',
+    };
+
+    const { rawResponse } = await model.doGenerate({
+      inputFormat: 'prompt',
+      mode: { type: 'regular' },
+      prompt: TEST_PROMPT,
+    });
+
+    expect(rawResponse?.headers).toStrictEqual({
+      // default headers:
+      'content-type': 'application/json',
+
+      // custom header
+      'test-header': 'test-value',
+    });
+  });
+
   it('should pass the model and the messages', async () => {
     prepareJsonResponse({ content: '' });
 
-    await provider.chat('gpt-3.5-turbo').doGenerate({
+    await model.doGenerate({
       inputFormat: 'prompt',
       mode: { type: 'regular' },
       prompt: TEST_PROMPT,
@@ -357,7 +375,7 @@ describe('doStream', () => {
       logprobs: TEST_LOGPROBS,
     });
 
-    const { stream } = await provider.chat('gpt-3.5-turbo').doStream({
+    const { stream } = await model.doStream({
       inputFormat: 'prompt',
       mode: { type: 'regular' },
       prompt: TEST_PROMPT,
@@ -412,7 +430,7 @@ describe('doStream', () => {
       'data: [DONE]\n\n',
     ];
 
-    const { stream } = await provider.chat('gpt-3.5-turbo').doStream({
+    const { stream } = await model.doStream({
       inputFormat: 'prompt',
       mode: {
         type: 'regular',
@@ -499,10 +517,34 @@ describe('doStream', () => {
     ]);
   });
 
+  it('should expose the raw response headers', async () => {
+    prepareStreamResponse({ content: [] });
+
+    server.responseHeaders = {
+      'test-header': 'test-value',
+    };
+
+    const { rawResponse } = await model.doStream({
+      inputFormat: 'prompt',
+      mode: { type: 'regular' },
+      prompt: TEST_PROMPT,
+    });
+
+    expect(rawResponse?.headers).toStrictEqual({
+      // default headers:
+      'content-type': 'text/event-stream',
+      'cache-control': 'no-cache',
+      connection: 'keep-alive',
+
+      // custom header
+      'test-header': 'test-value',
+    });
+  });
+
   it('should pass the messages and the model', async () => {
     prepareStreamResponse({ content: [] });
 
-    await provider.chat('gpt-3.5-turbo').doStream({
+    await model.doStream({
       inputFormat: 'prompt',
       mode: { type: 'regular' },
       prompt: TEST_PROMPT,
diff --git a/packages/openai/src/openai-chat-language-model.ts b/packages/openai/src/openai-chat-language-model.ts
index d96fcb9941d..c6a4e770254 100644
--- a/packages/openai/src/openai-chat-language-model.ts
+++ b/packages/openai/src/openai-chat-language-model.ts
@@ -153,7 +153,7 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 {
   ): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> {
     const args = this.getArgs(options);
 
-    const response = await postJsonToApi({
+    const { responseHeaders, value: response } = await postJsonToApi({
       url: `${this.config.baseURL}/chat/completions`,
       headers: this.config.headers(),
       body: args,
@@ -181,6 +181,7 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 {
         completionTokens: response.usage.completion_tokens,
       },
       rawCall: { rawPrompt, rawSettings },
+      rawResponse: { headers: responseHeaders },
       warnings: [],
       logprobs: mapOpenAIChatLogProbsOutput(choice.logprobs),
     };
@@ -191,7 +192,7 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 {
   ): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
     const args = this.getArgs(options);
 
-    const response = await postJsonToApi({
+    const { responseHeaders, value: response } = await postJsonToApi({
       url: `${this.config.baseURL}/chat/completions`,
       headers: this.config.headers(),
       body: {
@@ -358,6 +359,7 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 {
         }),
       ),
       rawCall: { rawPrompt, rawSettings },
+      rawResponse: { headers: responseHeaders },
       warnings: [],
     };
   }
diff --git a/packages/openai/src/openai-completion-language-model.test.ts b/packages/openai/src/openai-completion-language-model.test.ts
index b92eb6f7531..0364f3a6ede 100644
--- a/packages/openai/src/openai-completion-language-model.test.ts
+++ b/packages/openai/src/openai-completion-language-model.test.ts
@@ -38,9 +38,8 @@ const TEST_LOGPROBS = {
   ] as Record<string, number>[],
 };
 
-const provider = createOpenAI({
-  apiKey: 'test-api-key',
-});
+const provider = createOpenAI({ apiKey: 'test-api-key' });
+const model = provider.completion('gpt-3.5-turbo-instruct');
 
 describe('doGenerate', () => {
   const server = new JsonTestServer('https://api.openai.com/v1/completions');
@@ -90,13 +89,11 @@ describe('doGenerate', () => {
   it('should extract text response', async () => {
     prepareJsonResponse({ content: 'Hello, World!' });
 
-    const { text } = await provider
-      .completion('gpt-3.5-turbo-instruct')
-      .doGenerate({
-        inputFormat: 'prompt',
-        mode: { type: 'regular' },
-        prompt: TEST_PROMPT,
-      });
+    const { text } = await model.doGenerate({
+      inputFormat: 'prompt',
+      mode: { type: 'regular' },
+      prompt: TEST_PROMPT,
+    });
 
     expect(text).toStrictEqual('Hello, World!');
   });
@@ -107,13 +104,11 @@ describe('doGenerate', () => {
       usage: { prompt_tokens: 20, total_tokens: 25, completion_tokens: 5 },
     });
 
-    const { usage } = await provider
-      .completion('gpt-3.5-turbo-instruct')
-      .doGenerate({
-        inputFormat: 'prompt',
-        mode: { type: 'regular' },
-        prompt: TEST_PROMPT,
-      });
+    const { usage } = await model.doGenerate({
+      inputFormat: 'prompt',
+      mode: { type: 'regular' },
+      prompt: TEST_PROMPT,
+    });
 
     expect(usage).toStrictEqual({
       promptTokens: 20,
@@ -155,10 +150,32 @@ describe('doGenerate', () => {
     expect(finishReason).toStrictEqual('stop');
   });
 
+  it('should expose the raw response headers', async () => {
+    prepareJsonResponse({ content: '' });
+
+    server.responseHeaders = {
+      'test-header': 'test-value',
+    };
+
+    const { rawResponse } = await model.doGenerate({
+      inputFormat: 'prompt',
+      mode: { type: 'regular' },
+      prompt: TEST_PROMPT,
+    });
+
+    expect(rawResponse?.headers).toStrictEqual({
+      // default headers:
+      'content-type': 'application/json',
+
+      // custom header
+      'test-header': 'test-value',
+    });
+  });
+
   it('should pass the model and the prompt', async () => {
     prepareJsonResponse({ content: '' });
 
-    await provider.completion('gpt-3.5-turbo-instruct').doGenerate({
+    await model.doGenerate({
       inputFormat: 'prompt',
       mode: { type: 'regular' },
       prompt: TEST_PROMPT,
@@ -275,13 +292,11 @@ describe('doStream', () => {
       logprobs: TEST_LOGPROBS,
     });
 
-    const { stream } = await provider
-      .completion('gpt-3.5-turbo-instruct')
-      .doStream({
-        inputFormat: 'prompt',
-        mode: { type: 'regular' },
-        prompt: TEST_PROMPT,
-      });
+    const { stream } = await model.doStream({
+      inputFormat: 'prompt',
+      mode: { type: 'regular' },
+      prompt: TEST_PROMPT,
+    });
 
     // note: space moved to last chunk bc of trimming
     expect(await convertStreamToArray(stream)).toStrictEqual([
@@ -298,10 +313,34 @@ describe('doStream', () => {
     ]);
   });
 
+  it('should expose the raw response headers', async () => {
+    prepareStreamResponse({ content: [] });
+
+    server.responseHeaders = {
+      'test-header': 'test-value',
+    };
+
+    const { rawResponse } = await model.doStream({
+      inputFormat: 'prompt',
+      mode: { type: 'regular' },
+      prompt: TEST_PROMPT,
+    });
+
+    expect(rawResponse?.headers).toStrictEqual({
+      // default headers:
+      'content-type': 'text/event-stream',
+      'cache-control': 'no-cache',
+      connection: 'keep-alive',
+
+      // custom header
+      'test-header': 'test-value',
+    });
+  });
+
   it('should pass the model and the prompt', async () => {
     prepareStreamResponse({ content: [] });
 
-    await provider.completion('gpt-3.5-turbo-instruct').doStream({
+    await model.doStream({
       inputFormat: 'prompt',
       mode: { type: 'regular' },
       prompt: TEST_PROMPT,
diff --git a/packages/openai/src/openai-completion-language-model.ts b/packages/openai/src/openai-completion-language-model.ts
index f033758a12b..e413d6b0ef9 100644
--- a/packages/openai/src/openai-completion-language-model.ts
+++ b/packages/openai/src/openai-completion-language-model.ts
@@ -99,8 +99,6 @@ export class OpenAICompletionLanguageModel implements LanguageModelV1 {
       stop: stopSequences,
     };
 
-    console.log('BASE ARGS LOGS', baseArgs.logprobs);
-
     switch (type) {
       case 'regular': {
         if (mode.tools?.length) {
@@ -142,7 +140,7 @@ export class OpenAICompletionLanguageModel implements LanguageModelV1 {
   ): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> {
     const args = this.getArgs(options);
 
-    const response = await postJsonToApi({
+    const { responseHeaders, value: response } = await postJsonToApi({
       url: `${this.config.baseURL}/completions`,
       headers: this.config.headers(),
       body: args,
@@ -165,6 +163,7 @@ export class OpenAICompletionLanguageModel implements LanguageModelV1 {
       finishReason: mapOpenAIFinishReason(choice.finish_reason),
       logprobs: mapOpenAICompletionLogProbs(choice.logprobs),
       rawCall: { rawPrompt, rawSettings },
+      rawResponse: { headers: responseHeaders },
       warnings: [],
     };
   }
@@ -174,7 +173,7 @@ export class OpenAICompletionLanguageModel implements LanguageModelV1 {
   ): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
     const args = this.getArgs(options);
 
-    const response = await postJsonToApi({
+    const { responseHeaders, value: response } = await postJsonToApi({
       url: `${this.config.baseURL}/completions`,
       headers: this.config.headers(),
       body: {
@@ -251,6 +250,7 @@ export class OpenAICompletionLanguageModel implements LanguageModelV1 {
         }),
       ),
       rawCall: { rawPrompt, rawSettings },
+      rawResponse: { headers: responseHeaders },
       warnings: [],
     };
   }
diff --git a/packages/provider-utils/src/extract-response-headers.ts b/packages/provider-utils/src/extract-response-headers.ts
new file mode 100644
index 00000000000..f465a07ff5f
--- /dev/null
+++ b/packages/provider-utils/src/extract-response-headers.ts
@@ -0,0 +1,15 @@
+/**
+Extracts the headers from a response object and returns them as a key-value object.
+
+@param response - The response object to extract headers from.
+@returns The headers as a key-value object.
+*/
+export function extractResponseHeaders(
+  response: Response,
+): Record<string, string> {
+  const headers: Record<string, string> = {};
+  response.headers.forEach((value, key) => {
+    headers[key] = value;
+  });
+  return headers;
+}
diff --git a/packages/provider-utils/src/index.ts b/packages/provider-utils/src/index.ts
index 7656476a44b..a9a78c74cec 100644
--- a/packages/provider-utils/src/index.ts
+++ b/packages/provider-utils/src/index.ts
@@ -1,3 +1,4 @@
+export * from './extract-response-headers';
 export * from './generate-id';
 export * from './get-error-message';
 export * from './load-api-key';
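A quick illustration of the helper's behavior — the Fetch API treats header names case-insensitively and iterates them lower-cased, which is why the tests above assert lower-cased keys:

```ts
import { extractResponseHeaders } from '@ai-sdk/provider-utils';

const response = new Response('{}', {
  headers: { 'Content-Type': 'application/json', 'X-Request-Id': 'req_123' },
});

// keys come back lower-cased:
// { 'content-type': 'application/json', 'x-request-id': 'req_123' }
console.log(extractResponseHeaders(response));
```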
diff --git a/packages/provider-utils/src/response-handler.ts b/packages/provider-utils/src/response-handler.ts
index 4b9bd28e1c5..63b5b5a8dad 100644
--- a/packages/provider-utils/src/response-handler.ts
+++ b/packages/provider-utils/src/response-handler.ts
@@ -1,16 +1,20 @@
-import { APICallError, NoResponseBodyError } from '@ai-sdk/provider';
+import { APICallError, EmptyResponseBodyError } from '@ai-sdk/provider';
 import {
   EventSourceParserStream,
   ParsedEvent,
 } from 'eventsource-parser/stream';
 import { ZodSchema } from 'zod';
+import { extractResponseHeaders } from './extract-response-headers';
 import { ParseResult, parseJSON, safeParseJSON } from './parse-json';
 
 export type ResponseHandler<RETURN_TYPE> = (options: {
   url: string;
   requestBodyValues: unknown;
   response: Response;
-}) => PromiseLike<RETURN_TYPE>;
+}) => PromiseLike<{
+  value: RETURN_TYPE;
+  responseHeaders?: Record<string, string>;
+}>;
 
 export const createJsonErrorResponseHandler =
   <T>({
@@ -24,17 +28,22 @@ export const createJsonErrorResponseHandler =
   }): ResponseHandler<APICallError> =>
   async ({ response, url, requestBodyValues }) => {
     const responseBody = await response.text();
+    const responseHeaders = extractResponseHeaders(response);
 
     // Some providers return an empty response body for some errors:
     if (responseBody.trim() === '') {
-      return new APICallError({
-        message: response.statusText,
-        url,
-        requestBodyValues,
-        statusCode: response.status,
-        responseBody,
-        isRetryable: isRetryable?.(response),
-      });
+      return {
+        responseHeaders,
+        value: new APICallError({
+          message: response.statusText,
+          url,
+          requestBodyValues,
+          statusCode: response.status,
+          responseHeaders,
+          responseBody,
+          isRetryable: isRetryable?.(response),
+        }),
+      };
     }
 
     // resilient parsing in case the response is not JSON or does not match the schema:
@@ -44,24 +53,32 @@ export const createJsonErrorResponseHandler =
         schema: errorSchema,
       });
 
-      return new APICallError({
-        message: errorToMessage(parsedError),
-        url,
-        requestBodyValues,
-        statusCode: response.status,
-        responseBody,
-        data: parsedError,
-        isRetryable: isRetryable?.(response, parsedError),
-      });
+      return {
+        responseHeaders,
+        value: new APICallError({
+          message: errorToMessage(parsedError),
+          url,
+          requestBodyValues,
+          statusCode: response.status,
+          responseHeaders,
+          responseBody,
+          data: parsedError,
+          isRetryable: isRetryable?.(response, parsedError),
+        }),
+      };
     } catch (parseError) {
-      return new APICallError({
-        message: response.statusText,
-        url,
-        requestBodyValues,
-        statusCode: response.status,
-        responseBody,
-        isRetryable: isRetryable?.(response),
-      });
+      return {
+        responseHeaders,
+        value: new APICallError({
+          message: response.statusText,
+          url,
+          requestBodyValues,
+          statusCode: response.status,
+          responseHeaders,
+          responseBody,
+          isRetryable: isRetryable?.(response),
+        }),
+      };
     }
   };
 
@@ -70,30 +87,35 @@ export const createEventSourceResponseHandler =
   <T>(
     chunkSchema: ZodSchema<T>,
   ): ResponseHandler<ReadableStream<ParseResult<T>>> =>
   async ({ response }: { response: Response }) => {
+    const responseHeaders = extractResponseHeaders(response);
+
     if (response.body == null) {
-      throw new NoResponseBodyError();
+      throw new EmptyResponseBodyError({});
     }
 
-    return response.body
-      .pipeThrough(new TextDecoderStream())
-      .pipeThrough(new EventSourceParserStream())
-      .pipeThrough(
-        new TransformStream<ParsedEvent, ParseResult<T>>({
-          transform({ data }, controller) {
-            // ignore the 'DONE' event that e.g. OpenAI sends:
-            if (data === '[DONE]') {
-              return;
-            }
-
-            controller.enqueue(
-              safeParseJSON({
-                text: data,
-                schema: chunkSchema,
-              }),
-            );
-          },
-        }),
-      );
+    return {
+      responseHeaders,
+      value: response.body
+        .pipeThrough(new TextDecoderStream())
+        .pipeThrough(new EventSourceParserStream())
+        .pipeThrough(
+          new TransformStream<ParsedEvent, ParseResult<T>>({
+            transform({ data }, controller) {
+              // ignore the 'DONE' event that e.g. OpenAI sends:
+              if (data === '[DONE]') {
+                return;
+              }
+
+              controller.enqueue(
+                safeParseJSON({
+                  text: data,
+                  schema: chunkSchema,
+                }),
+              );
+            },
+          }),
+        ),
+    };
   };
 
 export const createJsonResponseHandler =
@@ -106,16 +128,22 @@ export const createJsonResponseHandler =
       schema: responseSchema,
     });
 
+    const responseHeaders = extractResponseHeaders(response);
+
     if (!parsedResult.success) {
       throw new APICallError({
         message: 'Invalid JSON response',
         cause: parsedResult.error,
         statusCode: response.status,
+        responseHeaders,
         responseBody,
         url,
         requestBodyValues,
       });
     }
 
-    return parsedResult.value;
+    return {
+      responseHeaders,
+      value: parsedResult.value,
+    };
   };
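Every `ResponseHandler` now resolves to a `{ value, responseHeaders }` envelope instead of the bare value, so the API-call helpers that invoke the handlers must unwrap it. A simplified, hypothetical sketch of the calling side, assuming the `ResponseHandler` type is re-exported from the package index (the real `postJsonToApi` also wires in the `failedResponseHandler`, abort signals, etc.):

```ts
import { ResponseHandler } from '@ai-sdk/provider-utils';

// hypothetical helper, not the actual post-to-api implementation:
async function postJson<T>({
  url,
  body,
  successfulResponseHandler,
}: {
  url: string;
  body: unknown;
  successfulResponseHandler: ResponseHandler<T>;
}): Promise<{ value: T; responseHeaders?: Record<string, string> }> {
  const response = await fetch(url, {
    method: 'POST',
    headers: { 'content-type': 'application/json' },
    body: JSON.stringify(body),
  });

  // the handler returns both the parsed value and the raw headers,
  // which providers then forward as rawResponse.headers:
  return successfulResponseHandler({
    response,
    url,
    requestBodyValues: body,
  });
}
```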
diff --git a/packages/provider-utils/src/test/json-test-server.ts b/packages/provider-utils/src/test/json-test-server.ts
index 78bc2da6489..98dcdae7f1c 100644
--- a/packages/provider-utils/src/test/json-test-server.ts
+++ b/packages/provider-utils/src/test/json-test-server.ts
@@ -4,6 +4,7 @@ import { SetupServer, setupServer } from 'msw/node';
 export class JsonTestServer {
   readonly server: SetupServer;
 
+  responseHeaders: Record<string, string> = {};
   responseBodyJson: any = {};
 
   request: Request | undefined;
@@ -15,7 +16,12 @@ export class JsonTestServer {
       http.post(url, ({ request }) => {
         this.request = request;
 
-        return HttpResponse.json(responseBodyJson());
+        return HttpResponse.json(responseBodyJson(), {
+          headers: {
+            'Content-Type': 'application/json',
+            ...this.responseHeaders,
+          },
+        });
       }),
     );
   }
diff --git a/packages/provider-utils/src/test/streaming-test-server.ts b/packages/provider-utils/src/test/streaming-test-server.ts
index 2faff5347c6..04fb77677f6 100644
--- a/packages/provider-utils/src/test/streaming-test-server.ts
+++ b/packages/provider-utils/src/test/streaming-test-server.ts
@@ -4,6 +4,7 @@ import { SetupServer, setupServer } from 'msw/node';
 export class StreamingTestServer {
   readonly server: SetupServer;
 
+  responseHeaders: Record<string, string> = {};
   responseChunks: any[] = [];
 
   request: Request | undefined;
@@ -34,6 +35,7 @@ export class StreamingTestServer {
           'Content-Type': 'text/event-stream',
           'Cache-Control': 'no-cache',
           Connection: 'keep-alive',
+          ...this.responseHeaders,
         },
       });
     }),
diff --git a/packages/provider/src/errors/api-call-error.ts b/packages/provider/src/errors/api-call-error.ts
index bebac577257..c1d7cdcbe0b 100644
--- a/packages/provider/src/errors/api-call-error.ts
+++ b/packages/provider/src/errors/api-call-error.ts
@@ -2,7 +2,10 @@ export class APICallError extends Error {
   readonly url: string;
   readonly requestBodyValues: unknown;
   readonly statusCode?: number;
+
+  readonly responseHeaders?: Record<string, string>;
   readonly responseBody?: string;
+
   readonly cause?: unknown;
   readonly isRetryable: boolean;
   readonly data?: unknown;
@@ -12,6 +15,7 @@ export class APICallError extends Error {
     url,
     requestBodyValues,
     statusCode,
+    responseHeaders,
     responseBody,
     cause,
     isRetryable = statusCode != null &&
@@ -25,6 +29,7 @@ export class APICallError extends Error {
     url: string;
     requestBodyValues: unknown;
     statusCode?: number;
+    responseHeaders?: Record<string, string>;
     responseBody?: string;
     cause?: unknown;
     isRetryable?: boolean;
@@ -37,6 +42,7 @@ export class APICallError extends Error {
     this.url = url;
     this.requestBodyValues = requestBodyValues;
     this.statusCode = statusCode;
+    this.responseHeaders = responseHeaders;
     this.responseBody = responseBody;
     this.cause = cause;
     this.isRetryable = isRetryable;
@@ -51,6 +57,8 @@ export class APICallError extends Error {
       typeof (error as APICallError).requestBodyValues === 'object' &&
       ((error as APICallError).statusCode == null ||
         typeof (error as APICallError).statusCode === 'number') &&
+      ((error as APICallError).responseHeaders == null ||
+        typeof (error as APICallError).responseHeaders === 'object') &&
       ((error as APICallError).responseBody == null ||
         typeof (error as APICallError).responseBody === 'string') &&
       ((error as APICallError).cause == null ||
@@ -68,6 +76,7 @@ export class APICallError extends Error {
       url: this.url,
       requestBodyValues: this.requestBodyValues,
       statusCode: this.statusCode,
+      responseHeaders: this.responseHeaders,
       responseBody: this.responseBody,
       cause: this.cause,
       isRetryable: this.isRetryable,
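Since `APICallError` now carries `responseHeaders`, error paths can surface provider request IDs too. A small sketch (the `x-request-id` header is illustrative; `isAPICallError` is the class's existing static type guard):

```ts
import { APICallError } from '@ai-sdk/provider';

async function withHeaderLogging<T>(call: () => PromiseLike<T>): Promise<T> {
  try {
    return await call();
  } catch (error) {
    if (APICallError.isAPICallError(error)) {
      // surface the provider request id for debugging/support:
      console.error(
        `API call failed (status ${error.statusCode}),`,
        `request id: ${error.responseHeaders?.['x-request-id']}`,
      );
    }
    throw error;
  }
}
```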
diff --git a/packages/provider/src/errors/empty-response-body-error.ts b/packages/provider/src/errors/empty-response-body-error.ts
new file mode 100644
index 00000000000..6f6dfee6398
--- /dev/null
+++ b/packages/provider/src/errors/empty-response-body-error.ts
@@ -0,0 +1,21 @@
+export class EmptyResponseBodyError extends Error {
+  constructor({ message = 'Empty response body' }: { message?: string } = {}) {
+    super(message);
+
+    this.name = 'AI_EmptyResponseBodyError';
+  }
+
+  static isEmptyResponseBodyError(
+    error: unknown,
+  ): error is EmptyResponseBodyError {
+    return error instanceof Error && error.name === 'AI_EmptyResponseBodyError';
+  }
+
+  toJSON() {
+    return {
+      name: this.name,
+      message: this.message,
+      stack: this.stack,
+    };
+  }
+}
diff --git a/packages/provider/src/errors/index.ts b/packages/provider/src/errors/index.ts
index a1f5ec0e8cf..ac654681c94 100644
--- a/packages/provider/src/errors/index.ts
+++ b/packages/provider/src/errors/index.ts
@@ -1,4 +1,5 @@
 export * from './api-call-error';
+export * from './empty-response-body-error';
 export * from './invalid-argument-error';
 export * from './invalid-data-content-error';
 export * from './invalid-prompt-error';
@@ -7,7 +8,6 @@ export * from './invalid-tool-arguments-error';
 export * from './json-parse-error';
 export * from './load-api-key-error';
 export * from './no-object-generated-error';
-export * from './no-response-body-error';
 export * from './no-such-tool-error';
 export * from './retry-error';
 export * from './tool-call-parse-error';
diff --git a/packages/provider/src/errors/no-response-body-error.ts b/packages/provider/src/errors/no-response-body-error.ts
deleted file mode 100644
index fa431b5fe2c..00000000000
--- a/packages/provider/src/errors/no-response-body-error.ts
+++ /dev/null
@@ -1,19 +0,0 @@
-export class NoResponseBodyError extends Error {
-  constructor({ message = 'No response body' }: { message?: string } = {}) {
-    super(message);
-
-    this.name = 'AI_NoResponseBodyError';
-  }
-
-  static isNoResponseBodyError(error: unknown): error is NoResponseBodyError {
-    return error instanceof Error && error.name === 'AI_NoResponseBodyError';
-  }
-
-  toJSON() {
-    return {
-      name: this.name,
-      message: this.message,
-      stack: this.stack,
-    };
-  }
-}
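The rename from `NoResponseBodyError` keeps the name-based guard pattern, so detection still works without `instanceof`. A tiny usage sketch (the duplicate-package rationale in the comment is an assumption about why the guard matches on `error.name`):

```ts
import { EmptyResponseBodyError } from '@ai-sdk/provider';

function describeStreamError(error: unknown): string {
  // matches on error.name, so it also works when multiple copies of
  // @ai-sdk/provider end up in node_modules (assumed rationale):
  if (EmptyResponseBodyError.isEmptyResponseBodyError(error)) {
    return `empty response body: ${error.message}`;
  }
  return error instanceof Error ? error.message : 'unknown error';
}
```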
+ */ + headers?: Record; + }; + warnings?: LanguageModelV1CallWarning[]; /** @@ -124,6 +134,16 @@ export type LanguageModelV1 = { rawSettings: Record; }; + /** + * Optional raw response data. + */ + rawResponse?: { + /** + * Response headers. + */ + headers?: Record; + }; + warnings?: LanguageModelV1CallWarning[]; }>; }; @@ -156,3 +176,5 @@ export type LanguageModelV1StreamPart = // error parts are streamed, allowing for multiple errors | { type: 'error'; error: unknown }; + +export type LanguageModelV1ResponseMetadata = {};