Skip to content

Commit

Permalink
Logprobs Support (#1397)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamStenner committed Apr 23, 2024
1 parent 75afa32 commit 4d683e1
Show file tree
Hide file tree
Showing 23 changed files with 806 additions and 28 deletions.
2 changes: 2 additions & 0 deletions docs/pages/docs/ai-core/openai.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ const model = openai.chat('gpt-3.5-turbo', {
// optional likelihood for specific tokens
'50256': -100,
},
logprobs: true, // provides log probabilities for each token
user: 'test-user', // optional unique user identifier
});
```
Expand All @@ -125,6 +126,7 @@ const model = openai.completion('gpt-3.5-turbo-instruct', {
// optional likelihood for specific tokens
'50256': -100,
},
logprobs: true, // provides log probabilities for each token
suffix: 'some text', // optional suffix that comes after a completion of inserted text
user: 'test-user', // optional unique user identifier
});
Expand Down
30 changes: 30 additions & 0 deletions examples/ai-core/src/generate-object/openai-full-json.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import { experimental_generateObject } from 'ai';
import { openai } from '@ai-sdk/openai';
import dotenv from 'dotenv';
import { z } from 'zod';

dotenv.config();

// Schema for the generated object: a list of RPG characters.
const characterSchema = z.object({
  characters: z.array(
    z.object({
      name: z.string(),
      class: z
        .string()
        .describe('Character class, e.g. warrior, mage, or thief.'),
      description: z.string(),
    }),
  ),
});

/**
 * Generates 3 fantasy RPG character descriptions as a structured object in
 * JSON mode, requesting the top-2 log probabilities per token from the model,
 * and prints the full result (including logprobs metadata).
 */
async function main() {
  const result = await experimental_generateObject({
    model: openai('gpt-4-turbo', { logprobs: 2 }),
    schema: characterSchema,
    mode: 'json',
    prompt:
      'Generate 3 character descriptions for a fantasy role playing game.',
  });

  console.log(result);
}

main().catch(console.error);
16 changes: 16 additions & 0 deletions examples/ai-core/src/generate-text/openai-logprobs.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { experimental_generateText } from 'ai';
import { openai } from '@ai-sdk/openai';
import dotenv from 'dotenv';

dotenv.config();

/**
 * Generates text with GPT-3.5 while requesting the top-2 log probabilities
 * for every generated token, then prints those logprobs.
 */
async function main() {
  const { logprobs } = await experimental_generateText({
    model: openai('gpt-3.5-turbo', { logprobs: 2 }),
    prompt: 'Invent a new holiday and describe its traditions.',
  });

  console.log(logprobs);
}

main().catch(console.error);
49 changes: 49 additions & 0 deletions examples/ai-core/src/stream-object/openai-fullstream.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import { openai } from '@ai-sdk/openai';
import { experimental_streamObject } from 'ai';
import dotenv from 'dotenv';
import { z } from 'zod';

dotenv.config();

// Schema for the streamed object: a list of RPG characters.
const characterSchema = z.object({
  characters: z.array(
    z.object({
      name: z.string(),
      class: z
        .string()
        .describe('Character class, e.g. warrior, mage, or thief.'),
      description: z.string(),
    }),
  ),
});

/**
 * Streams a structured object from the model (top-2 logprobs requested) and
 * consumes the full stream: partial objects are printed as they arrive, and
 * the finish metadata (finish reason, logprobs, usage) is printed at the end.
 */
async function main() {
  const result = await experimental_streamObject({
    model: openai('gpt-4-turbo', { logprobs: 2 }),
    maxTokens: 2000,
    schema: characterSchema,
    mode: 'json',
    prompt:
      'Generate 3 character descriptions for a fantasy role playing game.',
  });

  for await (const part of result.fullStream) {
    if (part.type === 'object') {
      console.clear();
      console.log(part.object);
    } else if (part.type === 'finish') {
      console.log('Finish reason:', part.finishReason);
      console.log('Logprobs:', part.logprobs);
      console.log('Usage:', part.usage);
    } else if (part.type === 'error') {
      console.error('Error:', part.error);
    }
  }
}

main().catch(console.error);
30 changes: 30 additions & 0 deletions examples/ai-core/src/stream-text/openai-fullstream-logprobs.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import { openai } from '@ai-sdk/openai';
import { experimental_streamText } from 'ai';
import dotenv from 'dotenv';

dotenv.config();

/**
 * Streams text from GPT-3.5 (top-2 logprobs requested) and consumes the full
 * stream, printing the accumulated token logprobs when the 'finish' part
 * arrives and reporting any stream errors.
 */
async function main() {
  const result = await experimental_streamText({
    model: openai('gpt-3.5-turbo', { logprobs: 2 }),
    maxTokens: 512,
    temperature: 0.3,
    maxRetries: 5,
    prompt: 'Invent a new holiday and describe its traditions.',
  });

  for await (const part of result.fullStream) {
    if (part.type === 'finish') {
      console.log('Logprobs:', part.logprobs);
    } else if (part.type === 'error') {
      console.error('Error:', part.error);
    }
  }
}

main().catch(console.error);
14 changes: 14 additions & 0 deletions packages/core/core/generate-object/generate-object.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import {
LanguageModelV1,
LanguageModelV1CallWarning,
LanguageModelV1FinishReason,
LanguageModelV1LogProbs,
NoTextGeneratedError,
} from '@ai-sdk/provider';
import { safeParseJSON } from '@ai-sdk/provider-utils';
Expand Down Expand Up @@ -93,6 +94,7 @@ Default and recommended: 'auto' (best mode for the model).
let finishReason: LanguageModelV1FinishReason;
let usage: Parameters<typeof calculateTokenUsage>[0];
let warnings: LanguageModelV1CallWarning[] | undefined;
let logprobs: LanguageModelV1LogProbs | undefined;

switch (mode) {
case 'json': {
Expand Down Expand Up @@ -120,6 +122,7 @@ Default and recommended: 'auto' (best mode for the model).
finishReason = generateResult.finishReason;
usage = generateResult.usage;
warnings = generateResult.warnings;
logprobs = generateResult.logprobs;

break;
}
Expand Down Expand Up @@ -149,6 +152,7 @@ Default and recommended: 'auto' (best mode for the model).
finishReason = generateResult.finishReason;
usage = generateResult.usage;
warnings = generateResult.warnings;
logprobs = generateResult.logprobs;

break;
}
Expand Down Expand Up @@ -188,6 +192,7 @@ Default and recommended: 'auto' (best mode for the model).
finishReason = generateResult.finishReason;
usage = generateResult.usage;
warnings = generateResult.warnings;
logprobs = generateResult.logprobs;

break;
}
Expand All @@ -213,6 +218,7 @@ Default and recommended: 'auto' (best mode for the model).
finishReason,
usage: calculateTokenUsage(usage),
warnings,
logprobs,
});
}

Expand Down Expand Up @@ -240,15 +246,23 @@ Warnings from the model provider (e.g. unsupported settings)
*/
readonly warnings: LanguageModelV1CallWarning[] | undefined;

/**
Logprobs for the completion.
`undefined` if the mode does not support logprobs or if logprobs were not enabled.
*/
readonly logprobs: LanguageModelV1LogProbs | undefined;

/**
 * Captures the fields produced by the provider call. Values are stored as-is;
 * `usage` is expected to already contain the computed token totals.
 */
constructor(options: {
object: T;
finishReason: LanguageModelV1FinishReason;
usage: TokenUsage;
warnings: LanguageModelV1CallWarning[] | undefined;
// undefined when the provider did not return logprobs (not enabled/supported)
logprobs: LanguageModelV1LogProbs | undefined;
}) {
this.object = options.object;
this.finishReason = options.finishReason;
this.usage = options.usage;
this.warnings = options.warnings;
this.logprobs = options.logprobs;
}
}
63 changes: 63 additions & 0 deletions packages/core/core/generate-object/stream-object.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,67 @@ describe('result.objectStream', () => {
],
);
});

// Verifies that result.fullStream emits incremental partial objects followed
// by a single 'finish' part carrying finishReason, totalled usage, and the
// logprobs passed through from the provider stream.
it('should send full stream data', async () => {
const result = await experimental_streamObject({
model: new MockLanguageModelV1({
doStream: async ({ prompt, mode }) => {
// The mock asserts the exact prompt/mode the streamObject call produces.
assert.deepStrictEqual(mode, { type: 'object-json' });
assert.deepStrictEqual(prompt, [
{
role: 'system',
content:
'JSON schema:\n' +
'{"type":"object","properties":{"content":{"type":"string"}},"required":["content"],"additionalProperties":false,"$schema":"http://json-schema.org/draft-07/schema#"}\n' +
'You MUST answer with a JSON object that matches the JSON schema above.',
},
{ role: 'user', content: [{ type: 'text', text: 'prompt' }] },
]);

// Simulated provider stream: JSON text deltas, then a finish part that
// includes usage and one logprob entry.
return {
stream: convertArrayToReadableStream([
{ type: 'text-delta', textDelta: '{ ' },
{ type: 'text-delta', textDelta: '"content": ' },
{ type: 'text-delta', textDelta: `"Hello, ` },
{ type: 'text-delta', textDelta: `world` },
{ type: 'text-delta', textDelta: `!"` },
{ type: 'text-delta', textDelta: ' }' },
{
type: 'finish',
finishReason: 'stop',
usage: { completionTokens: 10, promptTokens: 2 },
logprobs: [{ token: '-', logprob: 1, topLogprobs: [] }],
},
]),
rawCall: { rawPrompt: 'prompt', rawSettings: { logprobs: 0 } },
};
},
}),
schema: z.object({ content: z.string() }),
mode: 'json',
prompt: 'prompt',
});

// Expect partial objects as the JSON parses incrementally, then a finish
// part with totalTokens computed (2 + 10 = 12) and logprobs forwarded.
assert.deepStrictEqual(
await convertAsyncIterableToArray(result.fullStream),
[
{ type: 'object', object: {} },
{ type: 'object', object: { content: 'Hello, ' } },
{ type: 'object', object: { content: 'Hello, world' } },
{ type: 'object', object: { content: 'Hello, world!' } },
{
type: 'finish',
finishReason: 'stop',
usage: { promptTokens: 2, completionTokens: 10, totalTokens: 12 },
logprobs: [
{
token: '-',
logprob: 1,
topLogprobs: [],
},
],
},
],
);
});
});
Loading

0 comments on commit 4d683e1

Please sign in to comment.