Skip to content

Commit

Permalink
fix (provider/openai): introduce compatibility mode (#1595)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel committed May 15, 2024
1 parent 18a9655 commit 4e3c922
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 10 deletions.
5 changes: 5 additions & 0 deletions .changeset/unlucky-trainers-cheer.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@ai-sdk/openai': patch
---

fix (provider/openai): introduce compatibility mode in which "stream_options" are not sent
6 changes: 6 additions & 0 deletions content/providers/01-ai-sdk-providers/01-openai.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ You can use the following optional settings to customize the OpenAI provider ins

Custom headers to include in the requests.

- **compatibility** _"strict" | "compatible"_

OpenAI compatibility mode. Should be set to `strict` when using the OpenAI API,
and `compatible` when using 3rd party providers. In `compatible` mode, newer
information such as streamOptions are not being sent. Defaults to 'compatible'.

## Models

The OpenAI provider instance is a function that you can invoke to create a model:
Expand Down
27 changes: 27 additions & 0 deletions examples/ai-core/src/stream-text/fireworks.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import { createOpenAI } from '@ai-sdk/openai';
import { streamText } from 'ai';
import dotenv from 'dotenv';

dotenv.config();

const fireworks = createOpenAI({
apiKey: process.env.FIREWORKS_API_KEY ?? '',
baseURL: 'https://api.fireworks.ai/inference/v1',
});

async function main() {
const result = await streamText({
model: fireworks('accounts/fireworks/models/firefunction-v1'),
prompt: 'Invent a new holiday and describe its traditions.',
});

for await (const textPart of result.textStream) {
process.stdout.write(textPart);
}

console.log();
console.log('Token usage:', await result.usage);
console.log('Finish reason:', await result.finishReason);
}

main().catch(console.error);
6 changes: 5 additions & 1 deletion packages/openai/src/openai-chat-language-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ const TEST_LOGPROBS = {
],
};

const provider = createOpenAI({ apiKey: 'test-api-key' });
const provider = createOpenAI({
apiKey: 'test-api-key',
compatibility: 'strict',
});

const model = provider.chat('gpt-3.5-turbo');

describe('doGenerate', () => {
Expand Down
10 changes: 7 additions & 3 deletions packages/openai/src/openai-chat-language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import { mapOpenAIChatLogProbsOutput } from './map-openai-chat-logprobs';
type OpenAIChatConfig = {
provider: string;
baseURL: string;
compatibility: 'strict' | 'compatible';
headers: () => Record<string, string | undefined>;
};

Expand Down Expand Up @@ -198,9 +199,12 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 {
body: {
...args,
stream: true,
stream_options: {
include_usage: true,
},

// only include stream_options when in strict compatibility mode:
stream_options:
this.config.compatibility === 'strict'
? { include_usage: true }
: undefined,
},
failedResponseHandler: openaiFailedResponseHandler,
successfulResponseHandler: createEventSourceResponseHandler(
Expand Down
6 changes: 5 additions & 1 deletion packages/openai/src/openai-completion-language-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ const TEST_LOGPROBS = {
] as Record<string, number>[],
};

const provider = createOpenAI({ apiKey: 'test-api-key' });
const provider = createOpenAI({
apiKey: 'test-api-key',
compatibility: 'strict',
});

const model = provider.completion('gpt-3.5-turbo-instruct');

describe('doGenerate', () => {
Expand Down
10 changes: 7 additions & 3 deletions packages/openai/src/openai-completion-language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import { mapOpenAICompletionLogProbs } from './map-openai-completion-logprobs';
type OpenAICompletionConfig = {
provider: string;
baseURL: string;
compatibility: 'strict' | 'compatible';
headers: () => Record<string, string | undefined>;
};

Expand Down Expand Up @@ -179,9 +180,12 @@ export class OpenAICompletionLanguageModel implements LanguageModelV1 {
body: {
...this.getArgs(options),
stream: true,
stream_options: {
include_usage: true,
},

// only include stream_options when in strict compatibility mode:
stream_options:
this.config.compatibility === 'strict'
? { include_usage: true }
: undefined,
},
failedResponseHandler: openaiFailedResponseHandler,
successfulResponseHandler: createEventSourceResponseHandler(
Expand Down
2 changes: 2 additions & 0 deletions packages/openai/src/openai-facade.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ Custom headers to include in the requests.
return new OpenAIChatLanguageModel(modelId, settings, {
provider: 'openai.chat',
...this.baseConfig,
compatibility: 'strict',
});
}

Expand All @@ -83,6 +84,7 @@ Custom headers to include in the requests.
return new OpenAICompletionLanguageModel(modelId, settings, {
provider: 'openai.completion',
...this.baseConfig,
compatibility: 'strict',
});
}
}
18 changes: 16 additions & 2 deletions packages/openai/src/openai-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ OpenAI project.
Custom headers to include in the requests.
*/
headers?: Record<string, string>;

/**
OpenAI compatibility mode. Should be set to `strict` when using the OpenAI API,
and `compatible` when using 3rd party providers. In `compatible` mode, newer
information such as streamOptions are not being sent. Defaults to 'compatible'.
*/
compatibility?: 'strict' | 'compatible';
}

/**
Expand All @@ -89,6 +96,9 @@ export function createOpenAI(
withoutTrailingSlash(options.baseURL ?? options.baseUrl) ??
'https://api.openai.com/v1';

// we default to compatible, because strict breaks providers like Groq:
const compatibility = options.compatibility ?? 'compatible';

const getHeaders = () => ({
Authorization: `Bearer ${loadApiKey({
apiKey: options.apiKey,
Expand All @@ -108,6 +118,7 @@ export function createOpenAI(
provider: 'openai.chat',
baseURL,
headers: getHeaders,
compatibility,
});

const createCompletionModel = (
Expand All @@ -118,6 +129,7 @@ export function createOpenAI(
provider: 'openai.completion',
baseURL,
headers: getHeaders,
compatibility,
});

const createEmbeddingModel = (
Expand Down Expand Up @@ -158,6 +170,8 @@ export function createOpenAI(
}

/**
Default OpenAI provider instance.
Default OpenAI provider instance. It uses 'strict' compatibility mode.
*/
export const openai = createOpenAI();
export const openai = createOpenAI({
compatibility: 'strict', // strict for OpenAI API
});

0 comments on commit 4e3c922

Please sign in to comment.