ai/core: expose raw response headers #1417

Merged: 11 commits, Apr 23, 2024
11 changes: 11 additions & 0 deletions .changeset/short-seas-flash.md
@@ -0,0 +1,11 @@
---
'@ai-sdk/provider-utils': patch
'@ai-sdk/anthropic': patch
'@ai-sdk/provider': patch
'@ai-sdk/mistral': patch
'@ai-sdk/google': patch
'@ai-sdk/openai': patch
'ai': patch
---

ai/core: add support for getting raw response headers.
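For a non-streaming call, the new field can be read straight off the result. A minimal sketch, assuming the experimental_generateText API of this release line and a hypothetical x-request-id header (not guaranteed by any provider):

import { openai } from '@ai-sdk/openai';
import { experimental_generateText } from 'ai';

async function main() {
  const result = await experimental_generateText({
    model: openai('gpt-3.5-turbo'),
    prompt: 'Hello!',
  });

  // Both rawResponse and headers are optional, so guard the lookup.
  console.log(result.rawResponse?.headers?.['x-request-id']);
}

main().catch(console.error);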
24 changes: 24 additions & 0 deletions examples/ai-core/src/stream-text/openai-response-headers.ts
@@ -0,0 +1,24 @@
import { openai } from '@ai-sdk/openai';
import { experimental_streamText } from 'ai';
import dotenv from 'dotenv';

dotenv.config();

async function main() {
const result = await experimental_streamText({
model: openai('gpt-3.5-turbo'),
maxTokens: 512,
temperature: 0.3,
maxRetries: 5,
prompt: 'Invent a new holiday and describe its traditions.',
});

console.log(`Request ID: ${result.rawResponse?.headers?.['x-request-id']}`);
console.log();

for await (const textPart of result.textStream) {
process.stdout.write(textPart);
}
}

main().catch(console.error);
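Two details worth noting from this example and the tests below: the headers are available on the awaited result immediately, before the text stream is consumed, and header names arrive lowercased (the assertions expect 'content-type' and 'test-header'), matching how the Fetch API normalizes header names on iteration.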
51 changes: 47 additions & 4 deletions packages/anthropic/src/anthropic-messages-language-model.test.ts
@@ -11,10 +11,7 @@ const TEST_PROMPT: LanguageModelV1Prompt = [
{ role: 'user', content: [{ type: 'text', text: 'Hello' }] },
];

const provider = createAnthropic({
apiKey: 'test-api-key',
});

const provider = createAnthropic({ apiKey: 'test-api-key' });
const model = provider.chat('claude-3-haiku-20240307');

describe('doGenerate', () => {
@@ -181,6 +178,28 @@ describe('doGenerate', () => {
});
});

it('should expose the raw response headers', async () => {
prepareJsonResponse({});

server.responseHeaders = {
'test-header': 'test-value',
};

const { rawResponse } = await model.doGenerate({
inputFormat: 'prompt',
mode: { type: 'regular' },
prompt: TEST_PROMPT,
});

expect(rawResponse?.headers).toStrictEqual({
// default headers:
'content-type': 'application/json',

// custom header
'test-header': 'test-value',
});
});

it('should pass the model and the messages', async () => {
prepareJsonResponse({});

@@ -279,6 +298,30 @@ describe('doStream', () => {
]);
});

it('should expose the raw response headers', async () => {
prepareStreamResponse({ content: [] });

server.responseHeaders = {
'test-header': 'test-value',
};

const { rawResponse } = await model.doStream({
inputFormat: 'prompt',
mode: { type: 'regular' },
prompt: TEST_PROMPT,
});

expect(rawResponse?.headers).toStrictEqual({
// default headers:
'content-type': 'text/event-stream',
'cache-control': 'no-cache',
connection: 'keep-alive',

// custom header
'test-header': 'test-value',
});
});

it('should pass the messages and the model', async () => {
prepareStreamResponse({ content: [] });

6 changes: 4 additions & 2 deletions packages/anthropic/src/anthropic-messages-language-model.ts
@@ -164,7 +164,7 @@ export class AnthropicMessagesLanguageModel implements LanguageModelV1 {
): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> {
const { args, warnings } = this.getArgs(options);

const response = await postJsonToApi({
const { responseHeaders, value: response } = await postJsonToApi({
url: `${this.config.baseURL}/messages`,
headers: this.config.headers(),
body: args,
@@ -210,6 +210,7 @@ export class AnthropicMessagesLanguageModel implements LanguageModelV1 {
completionTokens: response.usage.output_tokens,
},
rawCall: { rawPrompt, rawSettings },
rawResponse: { headers: responseHeaders },
warnings,
};
}
@@ -219,7 +220,7 @@ export class AnthropicMessagesLanguageModel implements LanguageModelV1 {
): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
const { args, warnings } = this.getArgs(options);

const response = await postJsonToApi({
const { responseHeaders, value: response } = await postJsonToApi({
url: `${this.config.baseURL}/messages`,
headers: this.config.headers(),
body: {
@@ -296,6 +297,7 @@ export class AnthropicMessagesLanguageModel implements LanguageModelV1 {
}),
),
rawCall: { rawPrompt, rawSettings },
rawResponse: { headers: responseHeaders },
warnings,
};
}
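The responseHeaders destructured above come from @ai-sdk/provider-utils, whose postJsonToApi now returns the headers alongside the parsed response value. That provider-utils change is not part of this diff; a plausible sketch of the extraction, assuming the standard Fetch API Response type (the helper name extractResponseHeaders is hypothetical):

function extractResponseHeaders(response: Response): Record<string, string> {
  const headers: Record<string, string> = {};

  // Iterating a Headers object yields lowercased names per the Fetch spec.
  response.headers.forEach((value, key) => {
    headers[key] = value;
  });

  return headers;
}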
19 changes: 19 additions & 0 deletions packages/core/core/generate-object/generate-object.ts
@@ -94,6 +94,7 @@ Default and recommended: 'auto' (best mode for the model).
let finishReason: LanguageModelV1FinishReason;
let usage: Parameters<typeof calculateTokenUsage>[0];
let warnings: LanguageModelV1CallWarning[] | undefined;
let rawResponse: { headers?: Record<string, string> } | undefined;
let logprobs: LanguageModelV1LogProbs | undefined;

switch (mode) {
@@ -122,6 +123,7 @@ Default and recommended: 'auto' (best mode for the model).
finishReason = generateResult.finishReason;
usage = generateResult.usage;
warnings = generateResult.warnings;
rawResponse = generateResult.rawResponse;
logprobs = generateResult.logprobs;

break;
@@ -152,6 +154,7 @@ Default and recommended: 'auto' (best mode for the model).
finishReason = generateResult.finishReason;
usage = generateResult.usage;
warnings = generateResult.warnings;
rawResponse = generateResult.rawResponse;
logprobs = generateResult.logprobs;

break;
@@ -192,6 +195,7 @@ Default and recommended: 'auto' (best mode for the model).
finishReason = generateResult.finishReason;
usage = generateResult.usage;
warnings = generateResult.warnings;
rawResponse = generateResult.rawResponse;
logprobs = generateResult.logprobs;

break;
@@ -218,6 +222,7 @@ Default and recommended: 'auto' (best mode for the model).
finishReason,
usage: calculateTokenUsage(usage),
warnings,
rawResponse,
logprobs,
});
}
@@ -246,6 +251,16 @@ Warnings from the model provider (e.g. unsupported settings)
*/
readonly warnings: LanguageModelV1CallWarning[] | undefined;

/**
Optional raw response data.
*/
rawResponse?: {
/**
Response headers.
*/
headers?: Record<string, string>;
};

/**
Logprobs for the completion.
`undefined` if the mode does not support logprobs or if it was not enabled
@@ -257,12 +272,16 @@ Logprobs for the completion.
finishReason: LanguageModelV1FinishReason;
usage: TokenUsage;
warnings: LanguageModelV1CallWarning[] | undefined;
rawResponse?: {
headers?: Record<string, string>;
};
logprobs: LanguageModelV1LogProbs | undefined;
}) {
this.object = options.object;
this.finishReason = options.finishReason;
this.usage = options.usage;
this.warnings = options.warnings;
this.rawResponse = options.rawResponse;
this.logprobs = options.logprobs;
}
}
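The generate-object result gains the same optional field, so header access mirrors the text case. A sketch, assuming the experimental_generateObject API of this release line together with a Zod schema:

import { openai } from '@ai-sdk/openai';
import { experimental_generateObject } from 'ai';
import { z } from 'zod';

async function main() {
  const { object, rawResponse } = await experimental_generateObject({
    model: openai('gpt-3.5-turbo'),
    schema: z.object({ holiday: z.string() }),
    prompt: 'Invent a new holiday.',
  });

  console.log(object.holiday);
  // Headers may be absent, hence the optional chaining.
  console.log(rawResponse?.headers?.['content-type']);
}

main().catch(console.error);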
16 changes: 16 additions & 0 deletions packages/core/core/generate-object/stream-object.ts
@@ -220,6 +220,7 @@ Default and recommended: 'auto' (best mode for the model).
return new StreamObjectResult({
stream: result.stream.pipeThrough(new TransformStream(transformer)),
warnings: result.warnings,
rawResponse: result.rawResponse,
});
}

@@ -259,15 +260,30 @@ Warnings from the model provider (e.g. unsupported settings)
*/
readonly warnings: LanguageModelV1CallWarning[] | undefined;

/**
Optional raw response data.
*/
rawResponse?: {
/**
Response headers.
*/
headers?: Record<string, string>;
};

constructor({
stream,
warnings,
rawResponse,
}: {
stream: ReadableStream<string | ObjectStreamPartInput>;
warnings: LanguageModelV1CallWarning[] | undefined;
rawResponse?: {
headers?: Record<string, string>;
};
}) {
this.originalStream = stream;
this.warnings = warnings;
this.rawResponse = rawResponse;
}

get partialObjectStream(): AsyncIterableStream<DeepPartial<T>> {
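StreamObjectResult now carries rawResponse through from doStream, so, as with streamText, the headers can be inspected before the partial-object stream is read. A sketch under the same assumptions as the generateObject example above:

import { openai } from '@ai-sdk/openai';
import { experimental_streamObject } from 'ai';
import { z } from 'zod';

async function main() {
  const result = await experimental_streamObject({
    model: openai('gpt-3.5-turbo'),
    schema: z.object({ traditions: z.array(z.string()) }),
    prompt: 'Invent a new holiday and describe its traditions.',
  });

  // Available as soon as the call resolves, before consuming the stream.
  console.log(result.rawResponse?.headers?.['content-type']);

  for await (const partialObject of result.partialObjectStream) {
    console.log(partialObject);
  }
}

main().catch(console.error);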
15 changes: 15 additions & 0 deletions packages/core/core/generate-text/generate-text.ts
@@ -116,6 +116,7 @@ The tools that the model can call. The model needs to support calling tools.
finishReason: modelResponse.finishReason,
usage: calculateTokenUsage(modelResponse.usage),
warnings: modelResponse.warnings,
rawResponse: modelResponse.rawResponse,
logprobs: modelResponse.logprobs,
});
}
@@ -188,6 +189,16 @@ Warnings from the model provider (e.g. unsupported settings)
*/
readonly warnings: LanguageModelV1CallWarning[] | undefined;

/**
Optional raw response data.
*/
rawResponse?: {
/**
Response headers.
*/
headers?: Record<string, string>;
};

/**
Logprobs for the completion.
`undefined` if the mode does not support logprobs or if it was not enabled
@@ -201,6 +212,9 @@ Logprobs for the completion.
finishReason: LanguageModelV1FinishReason;
usage: TokenUsage;
warnings: LanguageModelV1CallWarning[] | undefined;
rawResponse?: {
headers?: Record<string, string>;
};
logprobs: LanguageModelV1LogProbs | undefined;
}) {
this.text = options.text;
@@ -209,6 +223,7 @@ Logprobs for the completion.
this.finishReason = options.finishReason;
this.usage = options.usage;
this.warnings = options.warnings;
this.rawResponse = options.rawResponse;
this.logprobs = options.logprobs;
}
}
3 changes: 1 addition & 2 deletions packages/core/core/generate-text/stream-text.test.ts
@@ -4,9 +4,8 @@ import { convertArrayToReadableStream } from '../test/convert-array-to-readable-
import { convertAsyncIterableToArray } from '../test/convert-async-iterable-to-array';
import { convertReadableStreamToArray } from '../test/convert-readable-stream-to-array';
import { MockLanguageModelV1 } from '../test/mock-language-model-v1';
import { experimental_streamText } from './stream-text';
import { ServerResponse } from 'node:http';
import { createMockServerResponse } from '../test/mock-server-response';
import { experimental_streamText } from './stream-text';

describe('result.textStream', () => {
it('should send text deltas', async () => {
18 changes: 17 additions & 1 deletion packages/core/core/generate-text/stream-text.ts
@@ -85,7 +85,7 @@ The tools that the model can call. The model needs to support calling tools.
}): Promise<StreamTextResult<TOOLS>> {
const retry = retryWithExponentialBackoff({ maxRetries });
const validatedPrompt = getValidatedPrompt({ system, prompt, messages });
const { stream, warnings } = await retry(() =>
const { stream, warnings, rawResponse } = await retry(() =>
model.doStream({
mode: {
type: 'regular',
@@ -112,6 +112,7 @@ The tools that the model can call. The model needs to support calling tools.
generatorStream: stream,
}),
warnings,
rawResponse,
});
}

@@ -152,15 +153,30 @@ Warnings from the model provider (e.g. unsupported settings)
*/
readonly warnings: LanguageModelV1CallWarning[] | undefined;

/**
Optional raw response data.
*/
rawResponse?: {
/**
Response headers.
*/
headers?: Record<string, string>;
};

constructor({
stream,
warnings,
rawResponse,
}: {
stream: ReadableStream<TextStreamPart<TOOLS>>;
warnings: LanguageModelV1CallWarning[] | undefined;
rawResponse?: {
headers?: Record<string, string>;
};
}) {
this.originalStream = stream;
this.warnings = warnings;
this.rawResponse = rawResponse;
}

/**